Exemple #1
0
    def __init__(self, data_source, mini_batch_size: int, num_instances: int = 4, seed: Optional[int] = None, with_mem_idx=False):
        """Build lookup tables for identity-balanced batch sampling.

        Args:
            data_source: iterable of items where item[1] is the person id
                and item[2] is the camera id.
            mini_batch_size: per-process batch size.
            num_instances: number of images sampled per identity.
            seed: RNG seed shared by all workers; generated when None.
            with_mem_idx: whether samples also carry a memory-bank index.
        """
        self.data_source = data_source
        self.num_instances = num_instances
        self.num_pids_per_batch = mini_batch_size // self.num_instances
        self.with_mem_idx = with_mem_idx

        # Rank/world size are needed once; the original code recomputed them
        # again at the end of __init__ — the duplicate has been removed.
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        # Global batch size across all distributed processes.
        self.batch_size = mini_batch_size * self._world_size

        # index -> pid, pid -> camids, pid -> dataset indices
        self.index_pid = defaultdict(int)
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        self.pids = sorted(self.pid_index.keys())
        self.num_identities = len(self.pids)

        # Every worker must agree on the seed so shuffles stay in sync.
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
    def __init__(self,
                 data_source: str,
                 mini_batch_size: int,
                 num_instances: int,
                 set_weight: list,
                 seed: Optional[int] = None):
        """Build lookup tables for camera-set re-weighted sampling.

        Args:
            data_source: iterable of items where item[1] is the person id
                and item[2] is the camera id.
            mini_batch_size: per-process batch size.
            num_instances: number of images sampled per identity.
            set_weight: per-set weights; the global batch must hold a whole
                number (> 1) of weighted sets.
            seed: RNG seed shared by all workers; generated when None.
        """
        self.data_source = data_source
        self.num_instances = num_instances
        self.num_pids_per_batch = mini_batch_size // self.num_instances

        self.set_weight = set_weight

        # Rank/world size are needed once; the duplicated re-assignment that
        # used to sit at the end of __init__ has been removed.
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self.batch_size = mini_batch_size * self._world_size

        # Images consumed by one full pass over the weighted sets.
        set_images = sum(self.set_weight) * self.num_instances
        assert self.batch_size % set_images == 0 and self.batch_size > set_images, \
            "Batch size must be divisible by the sum set weight"

        # index -> pid, pid -> camids, pid -> indices, camid -> pids
        self.index_pid = dict()
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)
        self.cam_pid = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)
            self.cam_pid[camid].append(pid)

        # Per-camera sampling probability of each pid, proportional to the
        # number of images that pid contributes to the camera.
        self.set_pid_prob = defaultdict(list)
        for camid, pid_list in self.cam_pid.items():
            index_per_pid = [len(self.pid_index[pid]) for pid in pid_list]
            cam_image_number = sum(index_per_pid)
            self.set_pid_prob[camid] = [i / cam_image_number for i in index_per_pid]

        self.pids = sorted(self.pid_index.keys())
        self.num_identities = len(self.pids)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
Exemple #3
0
def build_cls_test_loader(cfg, dataset_name, mapper=None, **kwargs):
    """Build a test-time DataLoader over a classification dataset's query split.

    Args:
        cfg: config node; reads TEST.IMS_PER_BATCH.
        dataset_name: registry key of the dataset to load.
        mapper: optional transform callable; defaults to cfg test transforms.
        **kwargs: forwarded to the dataset constructor.
    """
    cfg = cfg.clone()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
    if comm.is_main_process():
        dataset.show_test()

    transforms = mapper if mapper is not None else build_transforms(cfg, is_train=False)
    test_set = CommDataset(dataset.query, transforms, relabel=False)

    # The configured batch size is global; split it evenly per process.
    per_process_batch = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    inference_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(inference_sampler,
                                                  per_process_batch, False)
    return DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
Exemple #4
0
def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
    """
    Similar to `build_reid_train_loader`. This sampler coordinates all workers to produce
    the exact set of all samples.
    This interface is experimental.

    Args:
        test_set: dataset to iterate at test time.
        test_batch_size: global batch size summed over all processes.
        num_query: number of query samples, returned unchanged for the caller.
        num_workers: DataLoader worker process count.

    Returns:
        tuple: (DataLoader over ``test_set`` with test-time batching, num_query).
    """
    # Split the global batch size evenly across distributed processes.
    per_process_batch = test_batch_size // comm.get_world_size()
    inference_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(inference_sampler,
                                                  per_process_batch, False)
    loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=num_workers,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return loader, num_query
Exemple #5
0
def build_attr_train_loader(cfg):
    """Build the training DataLoader for pedestrian-attribute datasets.

    All datasets listed in cfg.DATASETS.NAMES must share the same attribute
    dictionary; an assertion enforces this.
    """
    train_items = []
    attr_dict = None
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root,
                                          combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if attr_dict is None:
            attr_dict = dataset.attr_dict
        else:
            assert attr_dict == dataset.attr_dict, f"attr_dict in {d} does not match with previous ones"
        train_items.extend(dataset.train)

    transforms = build_transforms(cfg, is_train=True)
    train_set = AttrDataset(train_items, transforms, attr_dict)

    # Global batch size split evenly across processes.
    per_process_batch = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
    sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler,
                                                          per_process_batch,
                                                          True)

    return torch.utils.data.DataLoader(
        train_set,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
Exemple #6
0
def build_reid_train_loader(
    train_set,
    *,
    sampler=None,
    total_batch_size,
    num_workers=0,
):
    """
    Build a dataloader for object re-identification with some default features.
    This interface is experimental.

    Args:
        train_set: a torch Dataset of training samples.
        sampler: an index sampler; required — this helper provides no default.
        total_batch_size: global batch size summed over all processes.
        num_workers: DataLoader worker process count.

    Returns:
        torch.utils.data.DataLoader: a dataloader.

    Raises:
        ValueError: if ``sampler`` is None.
    """
    # Fail fast with a clear message: passing None through to BatchSampler
    # only surfaces as an obscure error when the loader is first iterated.
    if sampler is None:
        raise ValueError("build_reid_train_loader requires an explicit `sampler`")

    # Each process handles an equal slice of the global batch.
    mini_batch_size = total_batch_size // comm.get_world_size()

    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, mini_batch_size, True)

    return torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
Exemple #7
0
def build_reid_test_loader(cfg, dataset_name):
    """Build a test DataLoader over query + gallery of one re-ID dataset.

    Returns:
        tuple: (DataLoader, number of query samples).
    """
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(
        root=_root, dataset_name=cfg.SPECIFIC_DATASET)
    if comm.is_main_process():
        dataset.show_test()

    transforms = build_transforms(cfg, is_train=False)
    # Query first, then gallery — the caller relies on this ordering via
    # the returned query count.
    test_set = CommDataset(dataset.query + dataset.gallery, transforms,
                           relabel=False)

    per_process_batch = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    inference_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(inference_sampler,
                                                  per_process_batch, False)
    loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=0,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return loader, len(dataset.query)
Exemple #8
0
def build_attr_test_loader(cfg, dataset_name):
    """Build a test DataLoader over a pedestrian-attribute dataset's test split.

    Args:
        cfg (CfgNode): reads TEST.IMS_PER_BATCH and DATASETS.COMBINEALL.
        dataset_name (str): registry key of the dataset to load.

    Returns:
        DataLoader yielding collated test batches.
    """
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(
        root=_root, combineall=cfg.DATASETS.COMBINEALL)
    if comm.is_main_process():
        dataset.show_test()
    test_items = dataset.test

    test_transforms = build_transforms(cfg, is_train=False)
    # NOTE(review): argument order here is (items, attr_dict, transforms),
    # while the train loader builds AttrDataset(items, transforms, attr_dict).
    # Confirm against AttrDataset.__init__ — one of the two call sites is
    # likely passing arguments in the wrong order.
    test_set = AttrDataset(test_items, dataset.attr_dict, test_transforms)

    # The configured batch size is global; split it per process.
    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler,
                                                  mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
Exemple #9
0
    def evaluate(self):
        """Cluster the gathered features and return clustering purity w.r.t. pids."""
        if comm.get_world_size() > 1:
            comm.synchronize()
            gathered = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*gathered))
            # Non-main processes contribute their predictions and stop here.
            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions

        features = torch.cat([p['feats'] for p in predictions], dim=0)
        pids = torch.cat([p['pids'] for p in predictions], dim=0).numpy()

        rerank_dist = compute_jaccard_distance(
            features,
            k1=self.cfg.CLUSTER.JACCARD.K1,
            k2=self.cfg.CLUSTER.JACCARD.K2,
        )
        pseudo_labels = self.cluster.fit_predict(rerank_dist)

        # Purity: fraction of samples whose cluster's majority pid matches.
        contingency = metrics.cluster.contingency_matrix(pids, pseudo_labels)
        return np.sum(np.amax(contingency, axis=0)) / np.sum(contingency)
Exemple #10
0
    def evaluate(self):
        """Compute top-1 accuracy over all gathered logits/labels."""
        if comm.get_world_size() > 1:
            comm.synchronize()
            gathered = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*gathered))
            # Only the main process reports results.
            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions

        pred_logits = torch.cat([p['logits'] for p in predictions], dim=0)
        labels = torch.cat([p['labels'] for p in predictions], dim=0)

        # measure accuracy and record loss
        acc1, = accuracy(pred_logits, labels, topk=(1,))

        results = OrderedDict()
        results["Acc@1"] = acc1
        results["metric"] = acc1
        self._results = results

        return copy.deepcopy(self._results)
Exemple #11
0
    def __init__(self, data_source: str, batch_size: int, num_instances: int, seed: Optional[int] = None):
        """Build lookup tables for identity sampling grouped by domain label.

        Args:
            data_source: iterable of items where item[1] is the person id and
                item[3]['domains'] is the domain label used as a camera id.
            batch_size: per-process batch size.
            num_instances: number of images sampled per identity.
            seed: RNG seed shared by all workers; generated when None.
        """
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = batch_size // self.num_instances

        # index -> pid is a scalar mapping assigned exactly once per key, so a
        # plain dict is correct here (defaultdict(list) was misleading).
        self.index_pid = dict()
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            # The domain label stands in for the raw camera id (info[2]).
            camid = info[3]['domains']
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        self.pids = list(self.pid_index.keys())
        self.num_identities = len(self.pids)

        # All workers must agree on the seed so shuffles stay in sync.
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
Exemple #12
0
    def evaluate(self):
        """Cluster the gathered features and return the Adjusted Rand Index."""
        if comm.get_world_size() > 1:
            comm.synchronize()
            gathered = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*gathered))
            # Non-main processes contribute their predictions and stop here.
            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions

        features = torch.cat([p['feats'] for p in predictions], dim=0)
        pids = torch.cat([p['pids'] for p in predictions], dim=0).numpy()

        rerank_dist = compute_jaccard_distance(
            features,
            k1=self.cfg.CLUSTER.JACCARD.K1,
            k2=self.cfg.CLUSTER.JACCARD.K2,
        )
        pseudo_labels = self.cluster.fit_predict(rerank_dist)

        # Agreement between the pseudo labels and the ground-truth pids.
        return metrics.adjusted_rand_score(pids, pseudo_labels)
Exemple #13
0
    def evaluate(self):
        """Compute top-1 accuracy from the stored logits and labels."""
        if comm.get_world_size() > 1:
            comm.synchronize()
            # Flatten the per-process lists gathered from every rank.
            pred_logits = sum(comm.gather(self.pred_logits), [])
            labels = sum(comm.gather(self.labels), [])

            # Only the main process reports results.
            if not comm.is_main_process():
                return {}
        else:
            pred_logits = self.pred_logits
            labels = self.labels

        pred_logits = torch.cat(pred_logits, dim=0)
        labels = torch.stack(labels)

        # measure accuracy and record loss
        acc1, = accuracy(pred_logits, labels, topk=(1,))

        results = OrderedDict()
        results["Acc@1"] = acc1
        results["metric"] = acc1
        self._results = results

        return copy.deepcopy(self._results)
Exemple #14
0
    def __call__(self, embedding, targets):
        """Circle-style pairwise loss on L2-normalized embeddings.

        In distributed mode the similarity is computed against embeddings
        gathered from every process.
        """
        embedding = F.normalize(embedding, dim=1)

        if comm.get_world_size() > 1:
            gallery_emb = concat_all_gather(embedding)
            gallery_tgt = concat_all_gather(targets)
        else:
            gallery_emb = embedding
            gallery_tgt = targets

        sim_mat = torch.matmul(embedding, gallery_emb.t())
        N, M = sim_mat.size()

        local_ids = targets.view(N, 1).expand(N, M)
        gallery_ids = gallery_tgt.view(M, 1).expand(M, N).t()
        is_pos = local_ids.eq(gallery_ids)
        is_neg = local_ids.ne(gallery_ids)

        # NOTE: view(N, -1) requires every row to have the same number of
        # positives (and negatives), as identity-balanced sampling provides.
        s_p = sim_mat[is_pos].contiguous().view(N, -1)
        s_n = sim_mat[is_neg].contiguous().view(N, -1)

        # Adaptive weights (detached) and fixed margins.
        alpha_p = F.relu(-s_p.detach() + 1 + self.m)
        alpha_n = F.relu(s_n.detach() + self.m)
        delta_p = 1 - self.m
        delta_n = self.m

        logit_p = -self.s * alpha_p * (s_p - delta_p)
        logit_n = self.s * alpha_n * (s_n - delta_n)

        loss = F.softplus(
            torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)
        ).mean()

        return loss * self._scale
Exemple #15
0
    def evaluate(self):
        """Run face-verification evaluation on the gathered features.

        Returns a dict with mean accuracy (percent), the mean best decision
        threshold, and the summary "metric"; also saves an ROC-curve image to
        ``<output_dir>/<dataset_name>_roc.png``.
        """
        if comm.get_world_size() > 1:
            comm.synchronize()
            features = comm.gather(self.features)
            # Flatten the per-process lists into one list of tensors.
            features = sum(features, [])

            # fmt: off
            if not comm.is_main_process(): return {}
            # fmt: on
        else:
            features = self.features

        features = torch.cat(features, dim=0)
        features = F.normalize(features, p=2, dim=1).numpy()

        self._results = OrderedDict()
        # Calls an external `evaluate` helper (which this method name shadows
        # inside the class) — presumably the verification routine returning
        # per-fold curves and accuracies; confirm against its definition.
        tpr, fpr, accuracy, best_thresholds = evaluate(features, self.labels)

        self._results["Accuracy"] = accuracy.mean() * 100
        self._results["Threshold"] = best_thresholds.mean()
        self._results["metric"] = accuracy.mean() * 100

        # Render the ROC curve and persist it next to the other outputs.
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)

        PathManager.mkdirs(self._output_dir)
        roc_curve.save(
            os.path.join(self._output_dir, self.dataset_name + "_roc.png"))

        return copy.deepcopy(self._results)
Exemple #16
0
    def build_test_loader(cls, cfg, test_set):
        """Build the evaluation DataLoaderX for ``test_set``.

        Args:
            cfg (CfgNode): reads TEST.IMS_PER_BATCH (global batch size) and
                DATALOADER.NUM_WORKERS.
            test_set: dataset to iterate at test time.

        Returns:
            DataLoaderX bound to the local rank, batching without dropping
            the last partial batch.
        """
        logger = logging.getLogger('fastreid')
        logger.info("Prepare testing loader")

        # The configured batch size is global; each process takes its share.
        test_batch_size = cfg.TEST.IMS_PER_BATCH
        mini_batch_size = test_batch_size // comm.get_world_size()
        num_workers = cfg.DATALOADER.NUM_WORKERS
        data_sampler = samplers.InferenceSampler(len(test_set))
        batch_sampler = BatchSampler(data_sampler, mini_batch_size, False)
        test_loader = DataLoaderX(
            comm.get_local_rank(),
            dataset=test_set,
            batch_sampler=batch_sampler,
            num_workers=num_workers,  # save some memory
            collate_fn=fast_batch_collator,
            pin_memory=True,
        )

        return test_loader
Exemple #17
0
    def build_train_loader(cls,
                           cfg,
                           train_set=None,
                           sampler=None,
                           with_mem_idx=False):
        """Build the training IterLoader.

        Args:
            cfg (CfgNode): reads SOLVER.IMS_PER_BATCH, SOLVER.ITERS and the
                DATALOADER.* sampler settings.
            train_set: training dataset; required when ``sampler`` is None.
            sampler: pre-built index sampler; when None one is constructed
                according to cfg.DATALOADER.SAMPLER_TRAIN.
            with_mem_idx: forwarded to Preprocessor so each sample also
                carries its memory-bank index.

        Returns:
            IterLoader yielding cfg.SOLVER.ITERS batches per epoch.

        Raises:
            ValueError: if cfg names an unknown training sampler.
        """
        logger = logging.getLogger('fastreid')
        logger.info("Prepare training loader")

        total_batch_size = cfg.SOLVER.IMS_PER_BATCH
        # The configured batch size is global; each process gets its share.
        mini_batch_size = total_batch_size // comm.get_world_size()

        if sampler is None:
            num_instance = cfg.DATALOADER.NUM_INSTANCE
            sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
            logger.info("Using training sampler {}".format(sampler_name))

            if sampler_name == "TrainingSampler":
                sampler = samplers.TrainingSampler(len(train_set))
            elif sampler_name == "NaiveIdentitySampler":
                sampler = samplers.NaiveIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance)
            elif sampler_name == "BalancedIdentitySampler":
                sampler = samplers.BalancedIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance)
            elif sampler_name == "SetReWeightSampler":
                set_weight = cfg.DATALOADER.SET_WEIGHT
                sampler = samplers.SetReWeightSampler(train_set.img_items,
                                                      mini_batch_size,
                                                      num_instance, set_weight)
            elif sampler_name == "ImbalancedDatasetSampler":
                sampler = samplers.ImbalancedDatasetSampler(
                    train_set.img_items)
            else:
                raise ValueError(
                    "Unknown training sampler: {}".format(sampler_name))

        iters = cfg.SOLVER.ITERS
        num_workers = cfg.DATALOADER.NUM_WORKERS
        batch_sampler = BatchSampler(sampler, mini_batch_size, True)

        # Wrap in IterLoader so one "epoch" is a fixed number of iterations.
        train_loader = IterLoader(
            DataLoader(
                Preprocessor(train_set, with_mem_idx),
                num_workers=num_workers,
                batch_sampler=batch_sampler,
                pin_memory=True,
            ),
            length=iters,
        )

        return train_loader
Exemple #18
0
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): full training configuration; consumed by the
                build_* factory methods invoked below.
        """
        super().__init__()
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        # Things to checkpoint alongside the model; amp state is added below
        # when mixed precision is enabled.
        optimizer_ckpt = dict(optimizer=optimizer)
        if cfg.SOLVER.FP16_ENABLED:
            # apex AMP, opt level O1 (mixed precision with patched functions).
            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
            optimizer_ckpt.update(dict(amp=amp))

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            # model = DistributedDataParallel(
            #     model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            # )
            model = DistributedDataParallel(model, delay_allreduce=True)

        self._trainer = (AMPTrainer if cfg.SOLVER.FP16_ENABLED else
                         SimpleTrainer)(model, data_loader, optimizer)

        self.iters_per_epoch = len(
            data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer,
                                                 self.iters_per_epoch)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            **optimizer_ckpt,
            # NOTE(review): unpacking self.scheduler implies build_lr_scheduler
            # returns a dict of checkpointable objects — confirm.
            **self.scheduler,
        )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
Exemple #19
0
    def __call__(self, embedding, targets):
        """Forward pass for all input predictions: preds - (batch_size x feat_dims) """

        # ------ differentiable ranking of all retrieval set ------
        embedding = F.normalize(embedding, dim=1)

        # NOTE(review): feat_dim is never used below.
        feat_dim = embedding.size(1)

        # For distributed training, gather all features from different process.
        if comm.get_world_size() > 1:
            all_embedding = concat_all_gather(embedding)
            all_targets = concat_all_gather(targets)
        else:
            all_embedding = embedding
            all_targets = targets

        # Cosine similarities of local queries against the full gallery.
        sim_dist = torch.matmul(embedding, all_embedding.t())
        N, M = sim_dist.size()

        # Compute the mask which ignores the relevance score of the query to itself
        mask_indx = 1.0 - torch.eye(M, device=sim_dist.device)
        mask_indx = mask_indx.unsqueeze(dim=0).repeat(N, 1, 1)  # (N, M, M)

        # sim_dist -> N, 1, M -> N, M, N
        sim_dist_repeat = sim_dist.unsqueeze(dim=1).repeat(1, M,
                                                           1)  # (N, M, M)
        # sim_dist_repeat_t = sim_dist.t().unsqueeze(dim=1).repeat(1, N, 1)  # (N, N, M)

        # Compute the difference matrix
        sim_diff = sim_dist_repeat - sim_dist_repeat.permute(0, 2,
                                                             1)  # (N, M, M)

        # Pass through the sigmoid (temperature-annealed smooth step).
        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask_indx

        # Compute all the rankings (sum over last dim of (N, M, M)).
        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1  # (N, M)

        pos_mask = targets.view(N, 1).expand(N, M).eq(
            all_targets.view(M, 1).expand(M, N).t()).float()  # (N, M)

        pos_mask_repeat = pos_mask.unsqueeze(1).repeat(1, M, 1)  # (N, M, M)

        # Compute positive rankings
        pos_sim_sg = sim_sg * pos_mask_repeat
        sim_pos_rk = torch.sum(pos_sim_sg, dim=-1) + 1  # (N, M)

        # sum the values of the Smooth-AP for all instances in the mini-batch
        # NOTE(review): the block indexing below assumes samples are grouped
        # contiguously by identity, `group` per id — confirm the sampler
        # guarantees this ordering.
        ap = 0
        group = N // self.num_id
        for ind in range(self.num_id):
            pos_divide = torch.sum(
                sim_pos_rk[(ind * group):((ind + 1) * group),
                           (ind * group):((ind + 1) * group)] /
                (sim_all_rk[(ind * group):((ind + 1) * group),
                            (ind * group):((ind + 1) * group)]))
            ap += pos_divide / torch.sum(pos_mask[ind * group]) / N
        return 1 - ap
Exemple #20
0
    def __init__(self, data_source: str, batch_size: int, num_instances: int, delete_rem: bool, seed: Optional[int] = None, cfg = None):
        """Build lookup tables for identity sampling, report a dataset
        histogram, and precompute the epoch length (``total_images``).

        Args:
            data_source: iterable of items where item[1] is the person id
                and item[2] is the camera id.
            batch_size: per-process batch size.
            num_instances: number of images sampled per identity.
            delete_rem: when True, identities with at least ``num_instances``
                images have their remainder images dropped instead of being
                padded up to the next multiple.
            seed: RNG seed shared by all workers; generated when None.
            cfg: unused here; kept for interface compatibility.
        """
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = batch_size // self.num_instances
        self.delete_rem = delete_rem

        # index -> pid, pid -> camids, pid -> dataset indices
        self.index_pid = defaultdict(list)
        self.pid_cam = defaultdict(list)
        self.pid_index = defaultdict(list)

        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        self.pids = list(self.pid_index.keys())
        self.num_identities = len(self.pids)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()


        # Print a short histogram of images-per-identity (first few bins and
        # the last bin) for a quick sanity check of the dataset balance.
        val_pid_index = [len(x) for x in self.pid_index.values()]
        min_v = min(val_pid_index)
        max_v = max(val_pid_index)
        hist_pid_index = [val_pid_index.count(x) for x in range(min_v, max_v+1)]
        num_print = 5
        for i, x in enumerate(range(min_v, min_v+min(len(hist_pid_index), num_print))):
            print('dataset histogram [bin:{}, cnt:{}]'.format(x, hist_pid_index[i]))
        print('...')
        print('dataset histogram [bin:{}, cnt:{}]'.format(max_v, val_pid_index.count(max_v)))

        # Round each identity's image count to a multiple of num_instances:
        # pad up by default, or (with delete_rem) drop the remainder when the
        # identity already has at least num_instances images.
        val_pid_index_upper = []
        for x in val_pid_index:
            v_remain = x % self.num_instances
            if v_remain == 0:
                val_pid_index_upper.append(x)
            else:
                if self.delete_rem:
                    if x < self.num_instances:
                        val_pid_index_upper.append(x - v_remain + self.num_instances)
                    else:
                        val_pid_index_upper.append(x - v_remain)
                else:
                    val_pid_index_upper.append(x - v_remain + self.num_instances)

        total_images = sum(val_pid_index_upper)
        total_images = total_images - (total_images % self.batch_size) - self.batch_size # approx: round down to whole batches, then drop one more batch
        self.total_images = total_images
Exemple #21
0
    def __init__(self, embedding_size, num_classes, sample_rate, cls_type,
                 scale, margin):
        """Partial-FC style classification head: each rank owns a contiguous
        shard of the class weight matrix and (optionally) samples a subset of
        its classes each step.

        Args:
            embedding_size: feature dimension of the input embeddings.
            num_classes: total number of classes across all ranks.
            sample_rate: fraction of local classes sampled per step; 1 means
                use the full local shard.
            cls_type: name of a margin-softmax class in ``any_softmax``.
            scale: logit scale passed to the classifier.
            margin: margin passed to the classifier.
        """
        super().__init__()

        self.embedding_size = embedding_size
        self.num_classes = num_classes
        self.sample_rate = sample_rate

        self.world_size = comm.get_world_size()
        self.rank = comm.get_rank()
        self.local_rank = comm.get_local_rank()
        self.device = torch.device(f'cuda:{self.local_rank}')

        # Split num_classes as evenly as possible: the first
        # (num_classes % world_size) ranks get one extra class.
        self.num_local: int = self.num_classes // self.world_size + int(
            self.rank < self.num_classes % self.world_size)
        # First global class id owned by this rank.
        self.class_start: int = self.num_classes // self.world_size * self.rank + \
                                min(self.rank, self.num_classes % self.world_size)
        self.num_sample: int = int(self.sample_rate * self.num_local)

        # Look up the margin-softmax classifier class by name.
        self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale,
                                                        margin)
        """ TODO: consider resume training
        if resume:
            try:
                self.weight: torch.Tensor = torch.load(self.weight_name)
                logging.info("softmax weight resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
                logging.info("softmax weight resume fail!")

            try:
                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
                logging.info("softmax weight mom resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
                logging.info("softmax weight mom resume fail!")
        else:
        """
        # Local weight shard and its momentum buffer, initialized from N(0, 0.01).
        self.weight = torch.normal(0,
                                   0.01, (self.num_local, self.embedding_size),
                                   device=self.device)
        self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
        # NOTE(review): `logger` is not defined in this method — presumably a
        # module-level logger; confirm it exists at import time.
        logger.info("softmax weight init successfully!")
        logger.info("softmax weight mom init successfully!")
        # Dedicated CUDA stream for overlapping weight updates with compute.
        self.stream: torch.cuda.Stream = torch.cuda.Stream(self.local_rank)

        self.index = None
        if int(self.sample_rate) == 1:
            # Full shard: no sampling, weights trained directly as a Parameter.
            self.update = lambda: 0
            self.sub_weight = nn.Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            # Sampled shard: placeholder, filled with the sampled rows each step.
            self.sub_weight = nn.Parameter(
                torch.empty((0, 0), device=self.device))
Exemple #22
0
def build_reid_train_loader(cfg, mapper=None, **kwargs):
    """
    Build reid train loader

    Args:
        cfg (CfgNode): config controlling datasets, sampler choice and batch size
        mapper: optional transform callable applied to each image; when None,
            the default train-time transforms are built from ``cfg``
        **kwargs: forwarded to each dataset constructor

    Returns:
        torch.utils.data.DataLoader: a dataloader.
    """
    cfg = cfg.clone()

    # Concatenate the train splits of every configured dataset.
    train_items = list()
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root,
                                          combineall=cfg.DATASETS.COMBINEALL,
                                          **kwargs)
        if comm.is_main_process():
            dataset.show_train()
        train_items.extend(dataset.train)

    if mapper is not None:
        transforms = mapper
    else:
        transforms = build_transforms(cfg, is_train=True)

    train_set = CommDataset(train_items, transforms, relabel=True)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    # The configured batch size is global; split it evenly per process.
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    # PK sampling draws P identities x K instances; otherwise shuffle uniformly.
    if cfg.DATALOADER.PK_SAMPLER:
        if cfg.DATALOADER.NAIVE_WAY:
            data_sampler = samplers.NaiveIdentitySampler(
                train_set.img_items, mini_batch_size, num_instance)
        else:
            data_sampler = samplers.BalancedIdentitySampler(
                train_set.img_items, mini_batch_size, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        data_sampler, mini_batch_size, True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
Exemple #23
0
def _train_loader_from_config(cfg,
                              *,
                              train_set=None,
                              transforms=None,
                              sampler=None,
                              **kwargs):
    """Assemble the keyword arguments for a reid training loader from ``cfg``.

    Any of ``train_set``/``transforms``/``sampler`` may be supplied by the
    caller; missing pieces are constructed from the config.

    Returns:
        dict: kwargs (``train_set``, ``sampler``, ``total_batch_size``,
        ``num_workers``) for the loader factory.
    """
    if transforms is None:
        transforms = build_transforms(cfg, is_train=True)

    if train_set is None:
        items = list()
        for name in cfg.DATASETS.NAMES:
            data = DATASET_REGISTRY.get(name)(root=_root, **kwargs)
            # Only rank 0 prints the dataset statistics.
            if comm.is_main_process():
                data.show_train()
            items.extend(data.train)
        train_set = CommDataset(items, transforms, relabel=True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        num_instance = cfg.DATALOADER.NUM_INSTANCE
        mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))

        # Lazily-evaluated constructors: config fields specific to one
        # sampler (e.g. SET_WEIGHT) are only read if that sampler is chosen.
        factories = {
            "TrainingSampler":
                lambda: samplers.TrainingSampler(len(train_set)),
            "NaiveIdentitySampler":
                lambda: samplers.NaiveIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance),
            "BalancedIdentitySampler":
                lambda: samplers.BalancedIdentitySampler(
                    train_set.img_items, mini_batch_size, num_instance),
            "SetReWeightSampler":
                lambda: samplers.SetReWeightSampler(
                    train_set.img_items, mini_batch_size, num_instance,
                    cfg.DATALOADER.SET_WEIGHT),
            "ImbalancedDatasetSampler":
                lambda: samplers.ImbalancedDatasetSampler(
                    train_set.img_items),
        }
        if sampler_name not in factories:
            raise ValueError(
                "Unknown training sampler: {}".format(sampler_name))
        sampler = factories[sampler_name]()

    return {
        "train_set": train_set,
        "sampler": sampler,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
Exemple #24
0
    def __init__(self, size: int):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
        """
        assert size > 0
        self._size = size
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        # Split [0, size) into world_size contiguous shards of ceil(size / world_size)
        # elements and keep only this rank's shard.
        per_shard = (self._size - 1) // self._world_size + 1
        start = per_shard * self._rank
        stop = min(start + per_shard, self._size)
        self._local_indices = range(start, stop)
Exemple #25
0
def build_reid_train_loader(cfg):
    """Build a distributed-aware training dataloader for (clothes-changing) reid.

    Mutates a cloned config: ``SOLVER.MAX_ITER`` is rescaled from epochs to
    iterations using the dataset size.

    Returns:
        torch.utils.data.DataLoader: a dataloader over the training set.
    """
    cfg = cfg.clone()
    cfg.defrost()

    train_items = list()
    for name in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(name)(root=_root,
                                             combineall=cfg.DATASETS.COMBINEALL)
        # Only rank 0 prints the dataset statistics.
        if comm.is_main_process():
            dataset.show_train()
        train_items.extend(dataset.train)

    # Convert MAX_ITER from epochs to iterations.
    cfg.SOLVER.MAX_ITER *= len(train_items) // cfg.SOLVER.IMS_PER_BATCH
    train_transforms = build_transforms(cfg, is_train=True)

    # CCDatasets handles clothes-changing datasets; CommDataset otherwise.
    dataset_cls = CCDatasets if cfg.DATALOADER.IS_CLO_CHANGES else CommDataset
    train_set = dataset_cls(train_items, train_transforms, relabel=True)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    # Per-process batch size.
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    if not cfg.DATALOADER.PK_SAMPLER:
        data_sampler = samplers.TrainingSampler(len(train_set))
    elif cfg.DATALOADER.NAIVE_WAY:
        # NOTE(review): the identity samplers receive the *global* batch size
        # here, while the BatchSampler below uses the per-process size —
        # confirm this asymmetry is intentional (the sibling loader builder
        # passes mini_batch_size instead).
        data_sampler = samplers.NaiveIdentitySampler(
            train_set.img_items, cfg.SOLVER.IMS_PER_BATCH, num_instance,
            None, True)
    else:
        data_sampler = samplers.BalancedIdentitySampler(
            train_set.img_items, cfg.SOLVER.IMS_PER_BATCH, num_instance)

    batch_sampler = torch.utils.data.sampler.BatchSampler(
        data_sampler, mini_batch_size, True)

    return torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
Exemple #26
0
    def __init__(self, cfg):
        """
        Build the trainer state from a config: dataloader, model, optimizer,
        LR scheduler, checkpointer and hooks.

        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(
                logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        # The dataloader is built first because auto_scale_hyperparams
        # rescales the config using it before the model/optimizer are built.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            model = DistributedDataParallel(model,
                                            device_ids=[comm.get_local_rank()],
                                            broadcast_buffers=False)

        super().__init__(model, data_loader, optimizer, cfg.SOLVER.BASE_LR,
                         cfg.MODEL.LOSSES.CENTER.LR,
                         cfg.MODEL.LOSSES.CENTER.SCALE, cfg.SOLVER.AMP_ENABLED)

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        # With SWA enabled, training is extended by SWA.ITER extra iterations.
        if cfg.SOLVER.SWA.ENABLED:
            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
        else:
            self.max_iter = cfg.SOLVER.MAX_ITER

        self.cfg = cfg

        self.register_hooks(self.build_hooks())
Exemple #27
0
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:
    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory
    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        # Read via a context manager so the file handle is not leaked
        # (the previous code called .read() on an unclosed handle).
        with PathManager.open(args.config_file, "r") as f:
            config_contents = f.read()
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file, config_contents
            )
        )

    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng()

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
    def evaluate(self):
        """Gather features across workers and compute Recall@K metrics.

        Returns an empty dict on non-main processes under distributed
        evaluation; otherwise a dict of ``Recall@k`` values plus ``metric``.
        """
        if comm.get_world_size() > 1:
            comm.synchronize()
            # Gather per-worker lists and flatten them into single lists.
            features = sum(comm.gather(self.features), [])
            labels = sum(comm.gather(self.labels), [])

            # fmt: off
            if not comm.is_main_process(): return {}
            # fmt: on
        else:
            features = self.features
            labels = self.labels

        features = torch.cat(features, dim=0)
        num_query = self._num_query

        # Query features/labels are the leading num_query entries.
        query_features = features[:num_query]
        query_labels = np.asarray(labels[:num_query])

        # Everything after that is the gallery.
        gallery_features = features[num_query:]
        gallery_pids = np.asarray(labels[num_query:])

        self._results = OrderedDict()

        if num_query == len(features):
            # No separate gallery: the query set is matched against itself.
            cmc = recall_at_ks(query_features,
                               query_labels,
                               self.recalls,
                               cosine=True)
        else:
            cmc = recall_at_ks(query_features,
                               query_labels,
                               self.recalls,
                               gallery_features,
                               gallery_pids,
                               cosine=True)

        for k in self.recalls:
            self._results['Recall@{}'.format(k)] = cmc[k]
        self._results["metric"] = cmc[self.recalls[0]]

        return copy.deepcopy(self._results)
Exemple #29
0
    def __call__(self, embedding, targets):
        """Circle-style pairwise loss on L2-normalized embeddings.

        Similarities are computed between the local embeddings and the
        (possibly gathered) embeddings of all workers; positive and negative
        pairs get adaptive margins scaled by self._s and weighted by
        self._scale.
        """
        embedding = nn.functional.normalize(embedding, dim=1)

        # Under DDP, compare local samples against embeddings from all workers.
        if comm.get_world_size() > 1:
            all_embedding = concat_all_gather(embedding)
            all_targets = concat_all_gather(targets)
        else:
            all_embedding = embedding
            all_targets = targets

        # Cosine similarity matrix (embeddings are normalized): N x M.
        dist_mat = torch.matmul(embedding, all_embedding.t())

        N, M = dist_mat.size()
        # is_pos[i, j] == 1 where local target i matches global target j.
        is_pos = targets.view(N, 1).expand(N, M).eq(
            all_targets.view(M, 1).expand(M, N).t()).float()

        # Compute the mask which ignores the relevance score of the query to itself
        # NOTE(review): this assumes the local N embeddings occupy the first N
        # columns of the gathered matrix — verify concat_all_gather ordering
        # holds for ranks > 0.
        if M > N:
            identity_indx = torch.eye(N, N, device=is_pos.device)
            remain_indx = torch.zeros(N, M - N, device=is_pos.device)
            identity_indx = torch.cat((identity_indx, remain_indx), dim=1)
            is_pos = is_pos - identity_indx
        else:
            is_pos = is_pos - torch.eye(N, N, device=is_pos.device)

        is_neg = targets.view(N, 1).expand(N, M).ne(
            all_targets.view(M, 1).expand(M, N).t())

        # Masked similarities for positive and negative pairs.
        s_p = dist_mat * is_pos
        s_n = dist_mat * is_neg

        # Adaptive weights (detached so they act as constants in the gradient).
        alpha_p = torch.clamp_min(-s_p.detach() + 1 + self._m, min=0.)
        alpha_n = torch.clamp_min(s_n.detach() + self._m, min=0.)
        # Margined decision boundaries for positives/negatives.
        delta_p = 1 - self._m
        delta_n = self._m

        logit_p = -self._s * alpha_p * (s_p - delta_p)
        logit_n = self._s * alpha_n * (s_n - delta_n)

        # softplus(logsumexp_p + logsumexp_n) averaged over the batch.
        loss = nn.functional.softplus(
            torch.logsumexp(logit_p, dim=1) +
            torch.logsumexp(logit_n, dim=1)).mean()

        return loss * self._scale
Exemple #30
0
def build_face_test_loader(cfg, dataset_name, **kwargs):
    """Build an inference dataloader for a face verification test set.

    Args:
        cfg (CfgNode): config providing TEST.IMS_PER_BATCH.
        dataset_name (str): registered dataset to load.
        **kwargs: forwarded to the dataset constructor.

    Returns:
        tuple: (DataLoader over the test set, its labels).
    """
    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
    # Only rank 0 prints the dataset summary.
    if comm.is_main_process():
        dataset.show_test()

    test_set = FaceCommDataset(dataset.carray, dataset.is_same)

    # Per-process batch size; keep the last partial batch (drop_last=False).
    batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, False)

    loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return loader, test_set.labels