Example #1
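    # Shared assumptions for these snippets: `from collections import defaultdict`,
    # `from typing import Optional` and a fastreid/detectron2-style distributed
    # utility module `comm` are imported at module level. Note that several
    # snippets annotate `data_source: str`, but it is actually an iterable of
    # (path, pid, camid, ...) records.
    # Example #1: identity-balanced (P x K) sampler init. Each mini-batch holds
    # `mini_batch_size // num_instances` identities with `num_instances` images
    # each; `batch_size` is the global size across all distributed ranks.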
    def __init__(self, data_source, mini_batch_size:int, num_instances:int=4, seed: Optional[int] = None, with_mem_idx=False):
        self.data_source = data_source
        self.num_instances = num_instances
        self.num_pids_per_batch = mini_batch_size // self.num_instances
        self.with_mem_idx = with_mem_idx

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self.batch_size = mini_batch_size * self._world_size

        self.index_pid = defaultdict(int)
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        self.pids = sorted(list(self.pid_index.keys()))
        self.num_identities = len(self.pids)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

Example #2
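    # Example #2 differs from #1 in reading the "camera" id from the per-item
    # metadata dict (info[3]['domains']) instead of from info[2].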
    def __init__(self, data_source: str, batch_size: int, num_instances: int, seed: Optional[int] = None):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = batch_size // self.num_instances

        self.index_pid = defaultdict(list)
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            # camid = info[2]
            camid = info[3]['domains']
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        self.pids = list(self.pid_index.keys())
        self.num_identities = len(self.pids)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
Example #3
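    # Example #3 adds `delete_rem`, which controls whether each identity's image
    # count is rounded down (remainder dropped) or padded up to a multiple of
    # `num_instances` when estimating the epoch length (`total_images`).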
    def __init__(self, data_source: str, batch_size: int, num_instances: int, delete_rem: bool, seed: Optional[int] = None, cfg = None):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = batch_size // self.num_instances
        self.delete_rem = delete_rem

        self.index_pid = defaultdict(list)
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        self.pids = list(self.pid_index.keys())
        self.num_identities = len(self.pids)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()


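        # Print a short histogram of images-per-identity (first `num_print` bins
        # and the last bin) for dataset inspection.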
        val_pid_index = [len(x) for x in self.pid_index.values()]
        min_v = min(val_pid_index)
        max_v = max(val_pid_index)
        hist_pid_index = [val_pid_index.count(x) for x in range(min_v, max_v+1)]
        num_print = 5
        for i, x in enumerate(range(min_v, min_v+min(len(hist_pid_index), num_print))):
            print('dataset histogram [bin:{}, cnt:{}]'.format(x, hist_pid_index[i]))
        print('...')
        print('dataset histogram [bin:{}, cnt:{}]'.format(max_v, val_pid_index.count(max_v)))

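        # Round each identity's count to a multiple of num_instances:
        # - delete_rem: drop the remainder, except identities with fewer than
        #   num_instances images, which are padded up instead;
        # - otherwise: pad the remainder up (those images get oversampled).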
        val_pid_index_upper = []
        for x in val_pid_index:
            v_remain = x % self.num_instances
            if v_remain == 0:
                val_pid_index_upper.append(x)
            else:
                if self.delete_rem:
                    if x < self.num_instances:
                        val_pid_index_upper.append(x - v_remain + self.num_instances)
                    else:
                        val_pid_index_upper.append(x - v_remain)
                else:
                    val_pid_index_upper.append(x - v_remain + self.num_instances)

        total_images = sum(val_pid_index_upper)
        total_images = total_images - (total_images % self.batch_size) - self.batch_size  # approx: round down to a batch multiple, minus one spare batch
        self.total_images = total_images
Example #4
    def __init__(self,
                 data_source: str,
                 mini_batch_size: int,
                 num_instances: int,
                 set_weight: list,
                 seed: Optional[int] = None):
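        # Example #4: camera-aware weighted sampler. `set_weight` gives per-set
        # sampling weights, and each camera's identities are later assigned
        # selection probabilities proportional to their image counts.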
        self.data_source = data_source
        self.num_instances = num_instances
        self.num_pids_per_batch = mini_batch_size // self.num_instances

        self.set_weight = set_weight

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self.batch_size = mini_batch_size * self._world_size

        assert self.batch_size % (sum(self.set_weight) * self.num_instances) == 0 and \
               self.batch_size > sum(self.set_weight) * self.num_instances, \
            "Batch size must be a multiple of (sum(set_weight) x num_instances)"

        self.index_pid = dict()
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)

        self.cam_pid = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)
            self.cam_pid[camid].append(pid)

        # Get sampler prob for each cam
        self.set_pid_prob = defaultdict(list)
        for camid, pid_list in self.cam_pid.items():
            index_per_pid = []
            for pid in pid_list:
                index_per_pid.append(len(self.pid_index[pid]))
            cam_image_number = sum(index_per_pid)
            prob = [i / cam_image_number for i in index_per_pid]
            self.set_pid_prob[camid] = prob

        self.pids = sorted(list(self.pid_index.keys()))
        self.num_identities = len(self.pids)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

Example #5
    def __init__(self,
                 size: int,
                 shuffle: bool = True,
                 seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
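
    # A minimal sketch of the companion `__iter__` this init typically pairs
    # with, mirroring detectron2's TrainingSampler (assumes `itertools` and
    # `torch` are imported):
    def __iter__(self):
        # Each rank takes every world_size-th index, offset by its own rank.
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        # Infinite, reproducible index stream; the shared seed keeps ranks aligned.
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g).tolist()
            else:
                yield from torch.arange(self._size).tolist()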
Example #6
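    # Example #6: class-balanced weighted sampler. Each sample is weighted by
    # the inverse of its class frequency, so rare classes are drawn as often as
    # common ones; `self._get_label` (defined elsewhere on the class) resolves a
    # sample's label, and `torch`, `List`, `Callable` are assumed imported.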
    def __init__(self, data_source: List, size: Optional[int] = None, seed: Optional[int] = None,
                 callback_get_label: Optional[Callable] = None):
        self.data_source = data_source
        # consider all elements in the dataset
        self.indices = list(range(len(data_source)))
        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self._size = len(self.indices) if size is None else size
        self.callback_get_label = callback_get_label

        # distribution of classes in the dataset
        label_to_count = {}
        for idx in self.indices:
            label = self._get_label(data_source, idx)
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # weight for each sample
        weights = [1.0 / label_to_count[self._get_label(data_source, idx)] for idx in self.indices]
        self.weights = torch.DoubleTensor(weights)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
Example #7
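    # Example #7: domain-balanced variant of Example #3. Identities are keyed
    # per domain (optionally treating each camera as a domain via
    # cfg.DATALOADER.CAMERA_TO_DOMAIN), and the batch is split evenly across
    # domains.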
    def __init__(self, data_source: str, batch_size: int, num_instances: int, delete_rem: bool, seed: Optional[int] = None, cfg = None):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = batch_size // self.num_instances
        self.delete_rem = delete_rem

        self.index_pid = defaultdict(list)
        self.pid_domain = defaultdict(list)
        self.pid_index = defaultdict(list)

        for index, info in enumerate(data_source):

            domainid = info[3]['domains']
            if cfg.DATALOADER.CAMERA_TO_DOMAIN:
                pid = info[1] + str(domainid)
            else:
                pid = info[1]
            self.index_pid[index] = pid
            # self.pid_domain[pid].append(domainid)
            self.pid_domain[pid] = domainid
            self.pid_index[pid].append(index)

        self.pids = list(self.pid_index.keys())
        self.domains = list(self.pid_domain.values())

        self.num_identities = len(self.pids)
        self.num_domains = len(set(self.domains))

        self.batch_size //= self.num_domains
        self.num_pids_per_batch //= self.num_domains

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()


        val_pid_index = [len(x) for x in self.pid_index.values()]
        min_v = min(val_pid_index)
        max_v = max(val_pid_index)
        hist_pid_index = [val_pid_index.count(x) for x in range(min_v, max_v+1)]
        num_print = 5
        for i, x in enumerate(range(min_v, min_v+min(len(hist_pid_index), num_print))):
            print('dataset histogram [bin:{}, cnt:{}]'.format(x, hist_pid_index[i]))
        print('...')
        print('dataset histogram [bin:{}, cnt:{}]'.format(max_v, val_pid_index.count(max_v)))

        val_pid_index_upper = []
        for x in val_pid_index:
            v_remain = x % self.num_instances
            if v_remain == 0:
                val_pid_index_upper.append(x)
            else:
                if self.delete_rem:
                    if x < self.num_instances:
                        val_pid_index_upper.append(x - v_remain + self.num_instances)
                    else:
                        val_pid_index_upper.append(x - v_remain)
                else:
                    val_pid_index_upper.append(x - v_remain + self.num_instances)

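        # Epoch length is driven by the largest domain, rounded down to a
        # multiple of the (already per-domain) batch size.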
        cnt_domains = [0 for x in range(self.num_domains)]
        for val, index in zip(val_pid_index_upper, self.domains):
            cnt_domains[index] += val
        self.max_cnt_domains = max(cnt_domains)
        self.total_images = self.num_domains * (self.max_cnt_domains - (self.max_cnt_domains % self.batch_size) - self.batch_size)
Example #8
def build_reid_train_loader(cfg):
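    """Build the main ReID train loader and, when meta-learning data is
    configured, the auxiliary meta-train / meta-test loaders (MetaBIN-style).
    Returns (train_loader, train_loader_add, cfg)."""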

    # build datasets
    cfg = cfg.clone()
    frozen = cfg.is_frozen()
    cfg.defrost()

    individual_flag_ori = cfg.DATALOADER.INDIVIDUAL
    individual_flag_meta = cfg.META.DATA.INDIVIDUAL
    if cfg.META.DATA.NAMES == "":
        individual_flag_meta = False
    gettrace = getattr(sys, 'gettrace', None)
    if gettrace is not None and gettrace():
        print('*' * 100)
        print('Hmm, Big Debugger is watching me')
        print('*' * 100)
        num_workers = 0
    else:
        num_workers = cfg.DATALOADER.NUM_WORKERS

    # transforms
    train_transforms = build_transforms(cfg, is_train=True, is_fake=False)
    if (cfg.META.DATA.NAMES != "") and \
            (cfg.META.DATA.LOADER_FLAG == 'synth' or cfg.META.DATA.SYNTH_FLAG != 'none'):
        synth_transforms = build_transforms(cfg, is_train=True, is_fake=True)
        cfg.META.DATA.LOADER_FLAG = 'each'
    else:
        synth_transforms = None
    train_set_all = []
    train_items = list()
    domain_idx = 0
    camera_all = list()

    # load datasets
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root,
                                          combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if len(dataset.train[0]) < 4:
            for i, x in enumerate(dataset.train):
                add_info = {}  # per-item metadata, appended as a 4th tuple element

                if cfg.DATALOADER.CAMERA_TO_DOMAIN:
                    add_info['domains'] = dataset.train[i][2]
                    camera_all.append(dataset.train[i][2])
                else:
                    add_info['domains'] = int(domain_idx)
                dataset.train[i] = list(dataset.train[i])
                dataset.train[i].append(add_info)
                dataset.train[i] = tuple(dataset.train[i])
        domain_idx += 1
        train_items.extend(dataset.train)
        if individual_flag_ori or individual_flag_meta:  # individual set
            train_set_all.append(dataset.train)

    if cfg.DATALOADER.CAMERA_TO_DOMAIN:  # used for single-source DG
        num_domains = len(set(camera_all))
    else:
        num_domains = domain_idx
    cfg.META.DATA.NUM_DOMAINS = num_domains

    if cfg.DATALOADER.NAIVE_WAY:
        logger.info('**[dataloader info: random domain shuffle]**')
    else:
        logger.info('**[dataloader info: uniform domain]**')
        logger.info(
            '**[The batch size should be a multiple of the number of domains.]**'
        )
        assert (cfg.SOLVER.IMS_PER_BATCH % (num_domains*cfg.DATALOADER.NUM_INSTANCE) == 0), \
            "cfg.SOLVER.IMS_PER_BATCH should be a multiple of (num_domain x num_instance)"
        assert (cfg.META.DATA.MTRAIN_MINI_BATCH % (num_domains*cfg.META.DATA.MTRAIN_NUM_INSTANCE) == 0), \
            "cfg.META.DATA.MTRAIN_MINI_BATCH should be a multiple of (num_domain x num_instance)"
        assert (cfg.META.DATA.MTEST_MINI_BATCH % (num_domains*cfg.META.DATA.MTEST_NUM_INSTANCE) == 0), \
            "cfg.META.DATA.MTEST_MINI_BATCH should be a multiple of (num_domain x num_instance)"

    if individual_flag_ori:
        cfg.SOLVER.IMS_PER_BATCH //= num_domains
    if individual_flag_meta:
        cfg.META.DATA.MTRAIN_MINI_BATCH //= num_domains
        cfg.META.DATA.MTEST_MINI_BATCH //= num_domains

    if 'keypoint' in cfg.META.DATA.NAMES:  # used for keypoint (not used in MetaBIN)
        cfg, train_set_all = make_keypoint_data(cfg=cfg,
                                                data_name=cfg.META.DATA.NAMES,
                                                train_items=train_items)

    train_set = CommDataset(train_items, train_transforms, relabel=True)
    if (synth_transforms is not None) and (cfg.META.DATA.NAMES != ""):  # used for synthetic (not used in MetaBIN)
        synth_set = CommDataset(train_items, synth_transforms, relabel=True)

    if individual_flag_ori or individual_flag_meta:  # for individual dataloader
        relabel_flag = False
        if individual_flag_meta:
            relabel_flag = cfg.META.DATA.RELABEL

        for i, x in enumerate(train_set_all):
            train_set_all[i] = CommDataset(x,
                                           train_transforms,
                                           relabel=relabel_flag)
            if not relabel_flag:
                train_set_all[i].relabel = True
                train_set_all[i].pid_dict = train_set.pid_dict
        # Check number of data
        cnt_data = 0
        for x in train_set_all:
            cnt_data += len(x.img_items)
        if cnt_data != len(train_set.img_items):
            print("data loading error, check build.py")

    if individual_flag_ori:  # for individual dataloader (domain-wise)
        train_loader = []
        if len(train_set_all) > 0:
            for i, x in enumerate(train_set_all):
                train_loader.append(
                    make_sampler(train_set=x,
                                 num_batch=cfg.SOLVER.IMS_PER_BATCH,
                                 num_instance=cfg.DATALOADER.NUM_INSTANCE,
                                 num_workers=num_workers,
                                 mini_batch_size=cfg.SOLVER.IMS_PER_BATCH //
                                 comm.get_world_size(),
                                 drop_last=cfg.DATALOADER.DROP_LAST,
                                 flag1=cfg.DATALOADER.NAIVE_WAY,
                                 flag2=cfg.DATALOADER.DELETE_REM,
                                 cfg=cfg))
    else:
        train_loader = make_sampler(train_set=train_set,
                                    num_batch=cfg.SOLVER.IMS_PER_BATCH,
                                    num_instance=cfg.DATALOADER.NUM_INSTANCE,
                                    num_workers=num_workers,
                                    mini_batch_size=cfg.SOLVER.IMS_PER_BATCH //
                                    comm.get_world_size(),
                                    drop_last=cfg.DATALOADER.DROP_LAST,
                                    flag1=cfg.DATALOADER.NAIVE_WAY,
                                    flag2=cfg.DATALOADER.DELETE_REM,
                                    cfg=cfg)

    train_loader_add = {}
    train_loader_add['mtrain'] = None  # mtrain dataset
    train_loader_add['mtest'] = None  # mtest dataset
    if cfg.META.DATA.NAMES != "":
        if cfg.META.DATA.LOADER_FLAG == 'each':  # "each": meta-init / meta-train / meta-test
            make_mtrain = True
            make_mtest = True
        elif cfg.META.DATA.LOADER_FLAG == 'diff':  # "diff": meta-init / meta-final
            make_mtrain = True
            make_mtest = False
        elif cfg.META.DATA.LOADER_FLAG == 'same':  # "same": meta-init
            make_mtrain = False
            make_mtest = False
        else:
            raise ValueError('unknown cfg.META.DATA.LOADER_FLAG: {}'.format(cfg.META.DATA.LOADER_FLAG))

        train_loader_add['mtrain'] = [] if make_mtrain else None
        train_loader_add['mtest'] = [] if make_mtest else None

        if cfg.META.DATA.SYNTH_SAME_SEED:
            seed = comm.shared_random_seed()
        else:
            seed = None

        if individual_flag_meta:  # for individual dataset (domain-wise)
            for i, x in enumerate(train_set_all):
                if make_mtrain:
                    train_loader_add['mtrain'].append(
                        make_sampler(
                            train_set=x,
                            num_batch=cfg.META.DATA.MTRAIN_MINI_BATCH,
                            num_instance=cfg.META.DATA.MTRAIN_NUM_INSTANCE,
                            num_workers=num_workers,
                            mini_batch_size=cfg.META.DATA.MTRAIN_MINI_BATCH //
                            comm.get_world_size(),
                            drop_last=cfg.META.DATA.DROP_LAST,
                            flag1=cfg.META.DATA.NAIVE_WAY,
                            flag2=cfg.META.DATA.DELETE_REM,
                            seed=seed,
                            cfg=cfg))
                if make_mtest:
                    train_loader_add['mtest'].append(
                        make_sampler(
                            train_set=x,
                            num_batch=cfg.META.DATA.MTEST_MINI_BATCH,
                            num_instance=cfg.META.DATA.MTEST_NUM_INSTANCE,
                            num_workers=num_workers,
                            mini_batch_size=cfg.META.DATA.MTEST_MINI_BATCH //
                            comm.get_world_size(),
                            drop_last=cfg.META.DATA.DROP_LAST,
                            flag1=cfg.META.DATA.NAIVE_WAY,
                            flag2=cfg.META.DATA.DELETE_REM,
                            seed=seed,
                            cfg=cfg))
        else:
            if make_mtrain:  # meta train dataset
                train_loader_add['mtrain'] = make_sampler(
                    train_set=train_set,
                    num_batch=cfg.META.DATA.MTRAIN_MINI_BATCH,
                    num_instance=cfg.META.DATA.MTRAIN_NUM_INSTANCE,
                    num_workers=num_workers,
                    mini_batch_size=cfg.META.DATA.MTRAIN_MINI_BATCH //
                    comm.get_world_size(),
                    drop_last=cfg.META.DATA.DROP_LAST,
                    flag1=cfg.META.DATA.NAIVE_WAY,
                    flag2=cfg.META.DATA.DELETE_REM,
                    seed=seed,
                    cfg=cfg)
            if make_mtest:  # meta test dataset
                if synth_transforms is None:
                    train_loader_add['mtest'] = make_sampler(
                        train_set=train_set,
                        num_batch=cfg.META.DATA.MTEST_MINI_BATCH,
                        num_instance=cfg.META.DATA.MTEST_NUM_INSTANCE,
                        num_workers=num_workers,
                        mini_batch_size=cfg.META.DATA.MTEST_MINI_BATCH //
                        comm.get_world_size(),
                        drop_last=cfg.META.DATA.DROP_LAST,
                        flag1=cfg.META.DATA.NAIVE_WAY,
                        flag2=cfg.META.DATA.DELETE_REM,
                        seed=seed,
                        cfg=cfg)
                else:
                    train_loader_add['mtest'] = make_sampler(
                        train_set=synth_set,
                        num_batch=cfg.META.DATA.MTEST_MINI_BATCH,
                        num_instance=cfg.META.DATA.MTEST_NUM_INSTANCE,
                        num_workers=num_workers,
                        mini_batch_size=cfg.META.DATA.MTEST_MINI_BATCH //
                        comm.get_world_size(),
                        drop_last=cfg.META.DATA.DROP_LAST,
                        flag1=cfg.META.DATA.NAIVE_WAY,
                        flag2=cfg.META.DATA.DELETE_REM,
                        seed=seed,
                        cfg=cfg)

    if frozen: cfg.freeze()

    return train_loader, train_loader_add, cfg
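
A minimal usage sketch, assuming a fastreid/MetaBIN-style cfg (the config path is hypothetical):

    cfg = get_cfg()                                   # fastreid-style default config
    cfg.merge_from_file('configs/sample-config.yml')  # hypothetical path
    train_loader, meta_loaders, cfg = build_reid_train_loader(cfg)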