Example #1
def build_attr_train_loader(cfg):
    train_items = list()
    attr_dict = None
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root,
                                          combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if attr_dict is not None:
            assert attr_dict == dataset.attr_dict, f"attr_dict in {d} does not match previous ones"
        else:
            attr_dict = dataset.attr_dict
        train_items.extend(dataset.train)

    train_transforms = build_transforms(cfg, is_train=True)
    train_set = AttrDataset(train_items, train_transforms, attr_dict)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        data_sampler, mini_batch_size, True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
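A minimal usage sketch, not taken from the source: it assumes a fastreid-style config from fastreid.config.get_cfg() and a registered attribute dataset; the dataset name and values below are illustrative.

from fastreid.config import get_cfg

cfg = get_cfg()
cfg.DATASETS.NAMES = ("PA100K",)     # hypothetical registered attribute dataset
cfg.SOLVER.IMS_PER_BATCH = 64        # global batch size, split across processes
cfg.DATALOADER.NUM_WORKERS = 8

train_loader = build_attr_train_loader(cfg)
batch = next(iter(train_loader))     # TrainingSampler is infinite, so just pull one batch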
Example #2
def build_reid_train_loader(cfg):
    train_transforms = build_transforms(cfg, is_train=True)

    logger = logging.getLogger(__name__)
    train_items = list()
    for d in cfg.DATASETS.NAMES:
        logger.info('prepare training set {}'.format(d))
        dataset = DATASET_REGISTRY.get(d)(cfg)
        train_items.extend(dataset.train)

    train_set = BlackreidDataset(train_items, train_transforms, mode='train', relabel=True)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    batch_size = cfg.SOLVER.IMS_PER_BATCH
    num_instance = cfg.DATALOADER.NUM_INSTANCE

    if cfg.DATALOADER.PK_SAMPLER:
        data_sampler = samplers.RandomIdentitySampler(train_set.img_items, batch_size, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, batch_size, True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
    )
    return data_prefetcher(cfg, train_loader)
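The PK_SAMPLER branch builds identity-balanced batches: each batch carries P identities with K = NUM_INSTANCE images each, so IMS_PER_BATCH should divide evenly by NUM_INSTANCE. A small sketch of that arithmetic, with illustrative values:

batch_size = 64       # cfg.SOLVER.IMS_PER_BATCH
num_instance = 4      # cfg.DATALOADER.NUM_INSTANCE (K images per identity)

assert batch_size % num_instance == 0, "batch must split evenly into identities"
num_identities = batch_size // num_instance   # P = 16 identities per batch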
Example #3
def build_attr_test_loader(cfg, dataset_name):
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(
        root=_root, combineall=cfg.DATASETS.COMBINEALL)
    if comm.is_main_process():
        dataset.show_test()
    test_items = dataset.test

    test_transforms = build_transforms(cfg, is_train=False)
    test_set = AttrDataset(test_items, test_transforms, dataset.attr_dict)

    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler,
                                                  mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
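A quick sanity-check sketch, under the assumption that fast_batch_collator yields a dict with an "images" tensor (as in fastreid's attribute datasets); the dataset name is a placeholder. With InferenceSampler and drop_last=False, the loader visits every test item exactly once:

test_loader = build_attr_test_loader(cfg, dataset_name="PA100K")

seen = 0
for batch in test_loader:            # finite: InferenceSampler has a fixed length
    seen += batch["images"].size(0)
assert seen == len(test_loader.dataset)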
Example #4
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_reid_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        logger = logging.getLogger("fastreid.clas_dataset")
        logger.info("Prepare training set")

        train_items = list()
        for d in cfg.DATASETS.NAMES:
            data = DATASET_REGISTRY.get(d)(root=_root)
            if comm.is_main_process():
                data.show_train()
            train_items.extend(data.train)

        transforms = build_transforms(cfg, is_train=True)
        train_set = ClasDataset(train_items, transforms)

        data_loader = build_reid_train_loader(cfg, train_set=train_set)

        # Save index to class dictionary
        output_dir = cfg.OUTPUT_DIR
        if comm.is_main_process() and output_dir:
            path = os.path.join(output_dir, "idx2class.json")
            with PathManager.open(path, "w") as f:
                json.dump(train_set.idx_to_class, f)

        return data_loader
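A round-trip sketch for the mapping saved above. JSON object keys are always strings, so integer indices come back as str and need casting on lookup:

import json
import os

with open(os.path.join(cfg.OUTPUT_DIR, "idx2class.json")) as f:
    idx2class = json.load(f)

class_name = idx2class[str(0)]   # look up the class name for predicted index 0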
Example #5
    def build_train_loader(cls, cfg):
        path_imgrec = cfg.DATASETS.REC_PATH
        if path_imgrec != "":
            transforms = build_transforms(cfg, is_train=True)
            train_set = MXFaceDataset(path_imgrec, transforms)
            return build_reid_train_loader(cfg, train_set=train_set)
        else:
            return DefaultTrainer.build_train_loader(cfg)
Example #6
    def build_test_loader(cls, cfg, dataset_name):
        dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
        attr_dict = dataset.attr_dict
        if comm.is_main_process():
            dataset.show_test()
        test_items = dataset.test

        test_transforms = build_transforms(cfg, is_train=False)
        test_set = AttrDataset(test_items, test_transforms, attr_dict)
        data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
        return data_loader
Example #7
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_reid_test_loader`.
        Overwrite it if you'd like a different data loader.
        """

        data = DATASET_REGISTRY.get(dataset_name)(root=_root)
        if comm.is_main_process():
            data.show_test()
        transforms = build_transforms(cfg, is_train=False)
        test_set = ClasDataset(data.query, transforms)
        data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
        return data_loader
Example #8
    def build_dataset(cls,
                      cfg,
                      img_items,
                      is_train=False,
                      relabel=False,
                      transforms=None,
                      with_mem_idx=False):
        if transforms is None:
            transforms = build_transforms(cfg, is_train=is_train)

        if with_mem_idx:
            sorted_img_items = sorted(img_items)
            for i in range(len(sorted_img_items)):
                sorted_img_items[i] += (i, )
            return InMemoryDataset(sorted_img_items, transforms, relabel)
        else:
            return CommDataset(img_items, transforms, relabel)
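An illustration of the with_mem_idx branch above, with made-up (path, pid, camid) tuples: items are sorted for a deterministic order, then each tuple is extended with its position as a memory index.

img_items = [("b.jpg", 1, 0), ("a.jpg", 0, 0)]   # illustrative (path, pid, camid) items

sorted_img_items = sorted(img_items)
for i in range(len(sorted_img_items)):
    sorted_img_items[i] += (i,)
# sorted_img_items == [("a.jpg", 0, 0, 0), ("b.jpg", 1, 0, 1)]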
Example #9
    def build_train_loader(cls, cfg):

        logger = logging.getLogger("fastreid.attr_dataset")
        train_items = list()
        attr_dict = None
        for d in cfg.DATASETS.NAMES:
            dataset = DATASET_REGISTRY.get(d)(
                root=_root, combineall=cfg.DATASETS.COMBINEALL)
            if comm.is_main_process():
                dataset.show_train()
            if attr_dict is not None:
                assert attr_dict == dataset.attr_dict, f"attr_dict in {d} does not match previous ones"
            else:
                attr_dict = dataset.attr_dict
            train_items.extend(dataset.train)

        train_transforms = build_transforms(cfg, is_train=True)
        train_set = AttrDataset(train_items, train_transforms, attr_dict)

        data_loader = build_reid_train_loader(cfg, train_set=train_set)
        AttrTrainer.sample_weights = data_loader.dataset.sample_weights
        return data_loader
Example #10
def build_reid_test_loader(cfg, dataset_name):
    test_transforms = build_transforms(cfg, is_train=False)

    logger = logging.getLogger(__name__)
    logger.info('prepare test set {}'.format(dataset_name))
    print('preparing test set...')
    dataset = DATASET_REGISTRY.get(dataset_name)(cfg)
    print(dataset)
    test_items = dataset.query + dataset.gallery

    test_set = BlackreidDataset(test_items, test_transforms, mode='test', relabel=False)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    batch_size = cfg.TEST.IMS_PER_BATCH
    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=fast_batch_collator)
    return data_prefetcher(cfg, test_loader), len(dataset.query)
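A usage sketch, with a placeholder dataset name: the loader comes back wrapped in data_prefetcher, and num_query marks where query items end and gallery items begin in the concatenated test list, which is how downstream evaluators split the extracted features.

test_loader, num_query = build_reid_test_loader(cfg, dataset_name="Market1501")

# Typical evaluator-side split of the feature matrix:
#   query_feats, gallery_feats = feats[:num_query], feats[num_query:]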
Example #11
    def build_train_loader(cls, cfg):
        """
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_reid_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        logger = logging.getLogger("fastreid.clas_dataset")
        logger.info("Prepare training set")

        train_items = list()
        for d in cfg.DATASETS.NAMES:
            data = DATASET_REGISTRY.get(d)(root=_root)
            if comm.is_main_process():
                data.show_train()
            train_items.extend(data.train)
        transforms = build_transforms(cfg, is_train=True)
        train_set = ClasDataset(train_items, transforms)
        cls.idx2class = train_set.idx_to_class

        data_loader = build_reid_train_loader(cfg, train_set=train_set)
        return data_loader
Example #12
    def eval_ARI_purity(cls, cfg, model, transforms=None, **kwargs):
        num_devices = comm.get_world_size()
        logger = logging.getLogger('fastreid')

        _root = os.getenv("FASTREID_DATASETS", "/root/datasets")
        evaluator = ARI_Purity_Evaluator(cfg)
        results = OrderedDict()

        if transforms is None:
            transforms = build_transforms(cfg, is_train=False)

        total_datasets = set(cfg.DATASETS.TESTS)
        for dataset_name in total_datasets:
            logger.info(f"Starting evaluating ARI on dataset {dataset_name}")

            data = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
            test_items = data.train
            test_set = CommDataset(test_items, transforms, relabel=True)
            data_loader, num_query = build_reid_test_loader(cfg,
                                                            dataset_name=dataset_name,
                                                            test_set=test_set)

            total = len(data_loader)  # inference data loader must have a fixed length
            evaluator.reset()

            img_nums = len(test_items)

            num_warmup = min(5, total - 1)
            start_time = time.perf_counter()
            total_compute_time = 0

            with inference_context(model), torch.no_grad():
                for idx, inputs in enumerate(data_loader):
                    if idx == num_warmup:
                        start_time = time.perf_counter()
                        total_compute_time = 0

                    start_compute_time = time.perf_counter()
                    outputs = model(inputs)

                    # Flip test
                    if cfg.TEST.FLIP.ENABLED:
                        inputs["images"] = inputs["images"].flip(dims=[3])
                        flip_outputs = model(inputs)
                        outputs = (outputs + flip_outputs) / 2

                    if torch.cuda.is_available():
                        torch.cuda.synchronize()
                    total_compute_time += time.perf_counter() - start_compute_time

                    evaluator.process(inputs, outputs)

                    iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
                    seconds_per_batch = total_compute_time / iters_after_start
                    if idx >= num_warmup * 2 or seconds_per_batch > 30:
                        total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                        eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                        log_every_n_seconds(
                            logging.INFO,
                            "Inference done {}/{}. {:.4f} s / batch. ETA={}".format(
                                idx + 1, total, seconds_per_batch, str(eta)
                            ),
                            n=30,
                        )

            # Measure the time only for this worker (before the synchronization barrier)
            total_time = time.perf_counter() - start_time
            total_time_str = str(datetime.timedelta(seconds=total_time))
            # NOTE this format is parsed by grep
            logger.info(
                "Total inference time: {} ({:.6f} s / batch per device, on {} devices)".format(
                    total_time_str, total_time / (total - num_warmup), num_devices
                )
            )
            total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
            logger.info(
                "Total inference pure compute time: {} ({:.6f} s / batch per device, on {} devices)".format(
                    total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
                )
            )
            results_i = evaluator.evaluate()
            ARI_score, purity = results_i
            results[f'{dataset_name}/ARI'] = ARI_score
            results[f'{dataset_name}/purity'] = purity

            if comm.is_main_process():
                assert isinstance(
                    results, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results
                )
                logger.info(f"ARI score for {dataset_name} is {ARI_score:.4f}")
                logger.info(f"Purity for {dataset_name} is {purity:.4f}")

        return results
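A minimal, self-contained sketch of the warmup-aware timing pattern used in the loop above: the clock restarts once num_warmup batches have passed, so one-time startup cost (CUDA context creation, cold data-loader workers) is excluded from the per-batch figure. run_one_batch is a hypothetical stand-in for the real model call.

import time

def run_one_batch():
    time.sleep(0.01)   # stand-in for model(inputs); purely illustrative

total = 20
num_warmup = min(5, total - 1)
start_time = time.perf_counter()
for idx in range(total):
    if idx == num_warmup:
        start_time = time.perf_counter()   # restart the clock after warmup
    run_one_batch()

seconds_per_batch = (time.perf_counter() - start_time) / (total - num_warmup)
print(f"{seconds_per_batch:.4f} s / batch over {total - num_warmup} timed batches")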