Example #1
def build_attr_train_loader(cfg):
    # Collect training items from every dataset listed in the config and
    # check that all datasets share the same attribute dictionary.
    train_items = list()
    attr_dict = None
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root,
                                          combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if attr_dict is not None:
            assert attr_dict == dataset.attr_dict, f"attr_dict in {d} does not match with previous ones"
        else:
            attr_dict = dataset.attr_dict
        train_items.extend(dataset.train)

    train_transforms = build_transforms(cfg, is_train=True)
    train_set = AttrDataset(train_items, train_transforms, attr_dict)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    # Per-process batch size: the global batch is split across all processes.
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    # TrainingSampler yields an infinite stream of shuffled indices.
    data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        data_sampler, mini_batch_size, drop_last=True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
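For context, here is a minimal usage sketch (not part of the original snippet). It assumes a fastreid-style config obtained via fastreid.config.get_cfg(), already extended with the attribute-project keys referenced above; the config file path is purely illustrative.

# Hypothetical usage sketch -- the config path is illustrative, and the config
# is assumed to already contain the DATASETS/DATALOADER/SOLVER keys used above.
from fastreid.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file("configs/attr_baseline.yml")  # illustrative path
train_loader = build_attr_train_loader(cfg)
batch = next(iter(train_loader))  # typically a dict of batched tensors built by fast_batch_collator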
Example #2
    def build_train_loader(cls,
                           cfg,
                           train_set=None,
                           sampler=None,
                           with_mem_idx=False):
        logger = logging.getLogger('fastreid')
        logger.info("Prepare training loader")

        total_batch_size = cfg.SOLVER.IMS_PER_BATCH
        mini_batch_size = total_batch_size // comm.get_world_size()

        if sampler is None:
            num_instance = cfg.DATALOADER.NUM_INSTANCE
            sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
            logger.info("Using training sampler {}".format(sampler_name))

            if sampler_name == "TrainingSampler":
                sampler = samplers.TrainingSampler(len(train_set))
            elif sampler_name == "NaiveIdentitySampler":
                sampler = samplers.NaiveIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance)
            elif sampler_name == "BalancedIdentitySampler":
                sampler = samplers.BalancedIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance)
            elif sampler_name == "SetReWeightSampler":
                set_weight = cfg.DATALOADER.SET_WEIGHT
                sampler = samplers.SetReWeightSampler(train_set.img_items,
                                                      mini_batch_size,
                                                      num_instance, set_weight)
            elif sampler_name == "ImbalancedDatasetSampler":
                sampler = samplers.ImbalancedDatasetSampler(
                    train_set.img_items)
            else:
                raise ValueError(
                    "Unknown training sampler: {}".format(sampler_name))

        iters = cfg.SOLVER.ITERS
        num_workers = cfg.DATALOADER.NUM_WORKERS
        batch_sampler = BatchSampler(sampler, mini_batch_size, drop_last=True)

        # IterLoader wraps the DataLoader so training runs for a fixed number
        # of iterations (cfg.SOLVER.ITERS) rather than whole dataset epochs.
        train_loader = IterLoader(
            DataLoader(
                Preprocessor(train_set, with_mem_idx),
                num_workers=num_workers,
                batch_sampler=batch_sampler,
                pin_memory=True,
            ),
            length=iters,
        )
        # train_loader = DataLoaderX(
        #     comm.get_local_rank(),
        #     dataset=Preprocessor(train_set, with_mem_idx),
        #     num_workers=num_workers,
        #     batch_sampler=batch_sampler,
        #     collate_fn=fast_batch_collator,
        #     pin_memory=True,
        # )

        return train_loader
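A hedged sketch of how this classmethod might be called is shown below; the trainer class name (Trainer), the pre-built train_set, and the chosen sampler name are assumptions for illustration, not values fixed by the snippet.

# Hypothetical invocation -- Trainer stands for whatever class defines build_train_loader above,
# and train_set is assumed to be a dataset exposing img_items as the samplers expect.
cfg.DATALOADER.SAMPLER_TRAIN = "NaiveIdentitySampler"  # any of the branches above
cfg.DATALOADER.NUM_INSTANCE = 4
train_loader = Trainer.build_train_loader(cfg, train_set=train_set, with_mem_idx=False)
for _ in range(cfg.SOLVER.ITERS):
    batch = train_loader.next()  # many IterLoader implementations expose .next(); adjust if yours is a plain iterator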
Example #3
def build_cls_train_loader(cfg, mapper=None, **kwargs):
    # Work on a copy so the caller's config is left untouched.
    cfg = cfg.clone()

    train_items = list()
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root, **kwargs)
        if comm.is_main_process():
            dataset.show_train()
        train_items.extend(dataset.train)

    if mapper is not None:
        transforms = mapper
    else:
        transforms = build_transforms(cfg, is_train=True)

    train_set = CommDataset(train_items, transforms, relabel=False)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    # PK sampling: each mini-batch holds num_instance images per identity;
    # otherwise fall back to a plain infinite shuffling sampler.
    if cfg.DATALOADER.PK_SAMPLER:
        if cfg.DATALOADER.NAIVE_WAY:
            data_sampler = samplers.NaiveIdentitySampler(
                train_set.img_items, mini_batch_size, num_instance)
        else:
            data_sampler = samplers.BalancedIdentitySampler(
                train_set.img_items, mini_batch_size, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        data_sampler, mini_batch_size, drop_last=True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
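The sampler choice here is driven entirely by config flags; a hedged sketch of the relevant toggles follows (only the key names come from the code above, the values are illustrative).

# Hypothetical config values -- key names taken from the example, values are illustrative.
cfg.DATALOADER.PK_SAMPLER = True    # identity-balanced (P x K) batches
cfg.DATALOADER.NAIVE_WAY = True     # NaiveIdentitySampler; False -> BalancedIdentitySampler
cfg.DATALOADER.NUM_INSTANCE = 4     # K: images per identity in a batch
cfg.SOLVER.IMS_PER_BATCH = 64       # global batch size, split across processes
train_loader = build_cls_train_loader(cfg)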