def validation_end(self, outputs):
    """Aggregate per-batch validation outputs and log re-ID retrieval metrics.

    Each element of ``outputs`` must be a dict holding a ``'feats'`` tensor
    and ``'pids'`` / ``'camids'`` sequences for the same samples. After
    concatenation the first ``self.num_query`` samples form the query set and
    the remainder form the gallery. Features are L2-normalised when
    ``self.cfg.TEST.NORM`` is set.

    Args:
        outputs: Per-batch dicts collected during validation.

    Returns:
        dict: ``{'rank1': cmc[0], 'mAP': mAP}``.
    """
    feats, pids, camids = [], [], []
    for o in outputs:
        feats.append(o['feats'])
        pids.extend(o['pids'])
        camids.extend(o['camids'])
    feats = torch.cat(feats, dim=0)
    if self.cfg.TEST.NORM:
        feats = F.normalize(feats, p=2, dim=1)

    # query
    qf = feats[:self.num_query]
    q_pids = np.asarray(pids[:self.num_query])
    q_camids = np.asarray(camids[:self.num_query])
    # gallery
    gf = feats[self.num_query:]
    g_pids = np.asarray(pids[self.num_query:])
    g_camids = np.asarray(camids[self.num_query:])

    # Squared Euclidean distance: ||q||^2 + ||g||^2 - 2 * q.g
    m, n = qf.shape[0], gf.shape[0]
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
              torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # Keyword form: the positional addmm_(beta, alpha, mat1, mat2) overload
    # is deprecated in PyTorch and emits a warning.
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    distmat = distmat.cpu().numpy()

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    self.logger.info(f"Test Results - Epoch: {self.current_epoch + 1}")
    self.logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
    tqdm_dic = {'rank1': cmc[0], 'mAP': mAP}
    return tqdm_dic
Exemple #2
0
 def on_epoch_end(self, epoch, **kwargs: Any):
     """Run re-ID evaluation every ``self._eval_period`` epochs and save a checkpoint.

     On evaluation epochs, features are extracted from ``self._test_dl``
     (first ``self._num_query`` samples are queries, the rest gallery),
     ranked by squared Euclidean distance, and scored with ``evaluate``
     against the stored ``self.q_pids`` / ``self.g_pids`` / ``self.q_camids``
     / ``self.g_camids`` labels. The results are logged and a checkpoint named
     ``model_{epoch}`` is saved through ``self.learn.save``. The model's
     train/eval mode from before the evaluation is restored afterwards.

     Args:
         epoch: Index of the epoch that just finished (logged as
             ``epoch + 1``).
         **kwargs: Further callback arguments; unused.
     """
     # test model performance only every `_eval_period` epochs
     if (epoch + 1) % self._eval_period != 0:
         return

     self._logger.info('Testing ...')
     model = self.learn.model
     was_training = model.training
     feats = []
     model.eval()
     try:
         with torch.no_grad():
             for imgs, _ in self._test_dl:
                 feats.append(model(imgs))
     finally:
         # Put the model back into the mode it was in before testing.
         model.train(was_training)

     feats = torch.cat(feats, dim=0)
     if self._norm:
         feats = F.normalize(feats, p=2, dim=1)
     # query
     qf = feats[:self._num_query]
     # gallery
     gf = feats[self._num_query:]

     # Squared Euclidean distance: ||q||^2 + ||g||^2 - 2 * q.g
     m, n = qf.shape[0], gf.shape[0]
     distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
               torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
     # Keyword form: the positional (beta, alpha, mat1, mat2) overload of
     # addmm_ is deprecated in PyTorch.
     distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
     distmat = to_np(distmat)

     cmc, mAP = evaluate(distmat, self.q_pids, self.g_pids,
                         self.q_camids, self.g_camids)
     self._logger.info(f"Test Results - Epoch: {epoch+1}")
     self._logger.info(f"mAP: {mAP:.1%}")
     for r in [1, 5, 10]:
         self._logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r-1]:.1%}")
     self.learn.save("model_{}".format(epoch))
def inference_with_distmat(
    cfg,
    test_dataloader,
    num_query,
    distmat,
):
    """Score a precomputed query/gallery distance matrix and log the metrics.

    Only the identity and camera labels are read from ``test_dataloader``;
    the images in each batch are ignored. The first ``num_query`` samples
    are taken as queries and the remainder as gallery, so ``distmat`` is
    presumably laid out as (num_query, num_gallery) in the same sample
    order. TODO confirm against the caller.

    Args:
        cfg: Config object forwarded to ``data_prefetcher``.
        test_dataloader: Loader whose batches unpack as ``(img, pid, camid)``.
        num_query: Number of leading samples that are queries.
        distmat: Distance matrix handed unchanged to ``evaluate``.
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    all_pids, all_camids = [], []
    prefetcher = data_prefetcher(test_dataloader, cfg)
    batch = prefetcher.next()
    while batch[0] is not None:
        _, batch_pids, batch_camids = batch
        all_pids.extend(batch_pids.cpu().numpy())
        all_camids.extend(np.asarray(batch_camids))
        batch = prefetcher.next()

    # Split labels into the query prefix and the gallery suffix.
    q_pids, g_pids = np.asarray(all_pids[:num_query]), np.asarray(all_pids[num_query:])
    q_camids, g_camids = np.asarray(all_camids[:num_query]), np.asarray(all_camids[num_query:])

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
    logger.info(f"Score: {(mAP + cmc[0]) / 2.:.1%}")
Exemple #4
0
    def test(self):
        """Evaluate the model on every configured test dataset.

        Datasets are taken in lockstep from ``cfg.DATASETS.TEST_NAMES``,
        ``self.val_dataloader_collection`` and
        ``self.num_query_len_collection``. For each one, the feature of a
        sample is the concatenation of ``output[1]`` and ``output[3]`` of the
        model. The first ``num_query`` samples are queries and the rest
        gallery, and ranking uses the negated inner product (negated cosine
        similarity when ``cfg.TEST.NORM`` is enabled).

        TensorBoard scalars are written as ``<dataset>/rank1`` and
        ``<dataset>/mAP`` for every dataset. The plain ``rank1`` / ``mAP``
        tags carry only the first dataset's values, which are also the
        returned ones. This keeps several datasets from overwriting the same
        tag at the same step.

        The model is switched back to train mode afterwards, even if
        evaluation raises.

        Returns:
            dict: ``{'rank1': ..., 'mAP': ...}`` for the first test dataset.

        Raises:
            IndexError: If no test datasets are configured.
        """
        metric_dict = list()
        # convert to eval mode
        self.model.eval()
        try:
            datasets = zip(self.cfg.DATASETS.TEST_NAMES,
                           self.val_dataloader_collection,
                           self.num_query_len_collection)
            for dataset_idx, (val_dataset_name, val_dataloader, num_query) in enumerate(datasets):
                feats, pids, camids = [], [], []
                val_prefetcher = data_prefetcher_mask(val_dataloader)
                batch = val_prefetcher.next()
                while batch[0] is not None:
                    img, mask, pid, camid = batch
                    adj_batch = self.adj.repeat(img.size(0), 1, 1)

                    with torch.no_grad():
                        output = self.model(img, img, mask, adj_batch)

                    feat = torch.cat([output[1], output[3]], dim=1)
                    feats.append(feat)
                    pids.extend(pid.cpu().numpy())
                    camids.extend(np.asarray(camid))

                    batch = val_prefetcher.next()

                feats = torch.cat(feats, dim=0)
                if self.cfg.TEST.NORM:
                    feats = F.normalize(feats, p=2, dim=1)
                # query
                qf = feats[:num_query]
                q_pids = np.asarray(pids[:num_query])
                q_camids = np.asarray(camids[:num_query])
                # gallery
                gf = feats[num_query:]
                g_pids = np.asarray(pids[num_query:])
                g_camids = np.asarray(camids[num_query:])

                # Inner-product similarity, negated so smaller means closer.
                distmat = torch.mm(qf, gf.t()).cpu().numpy()
                cmc, mAP = evaluate(-distmat, q_pids, g_pids, q_camids, g_camids)
                self.logger.info(
                    f"Test Results on {val_dataset_name} - Epoch: {self.current_epoch}"
                )
                self.logger.info(f"mAP: {mAP:.1%}")
                for r in [1, 5, 10]:
                    self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")

                # Per-dataset tags so datasets do not collide at one step.
                self.writer.add_scalar(f'{val_dataset_name}/rank1', cmc[0], self.global_step)
                self.writer.add_scalar(f'{val_dataset_name}/mAP', mAP, self.global_step)
                if dataset_idx == 0:
                    # Legacy un-namespaced tags: first dataset only, matching
                    # the metrics this method returns.
                    self.writer.add_scalar('rank1', cmc[0], self.global_step)
                    self.writer.add_scalar('mAP', mAP, self.global_step)
                metric_dict.append({'rank1': cmc[0], 'mAP': mAP})
        finally:
            # convert to train mode
            self.model.train()
        return metric_dict[0]
    def test(self):
        """Evaluate the model once on ``self.val_dataloader``.

        Features come from ``self.model(img)``, taking the first element when
        a tuple is returned. The first ``self.num_query`` samples are queries
        and the remainder gallery. Ranking uses the negated inner product
        (negated cosine similarity when ``cfg.TEST.NORM`` is enabled).

        Metrics are logged and written to ``self.writer`` under ``rank1`` and
        ``mAP`` at ``self.global_step``. The model is switched back to train
        mode before returning.

        Returns:
            dict: ``{'rank1': cmc[0], 'mAP': mAP}``.
        """
        self.model.eval()

        chunks, pid_list, camid_list = [], [], []
        prefetcher = data_prefetcher(self.val_dataloader, self.cfg)
        batch = prefetcher.next()
        while batch[0] is not None:
            img, pid, camid = batch
            with torch.no_grad():
                out = self.model(img)
            chunks.append(out[0] if isinstance(out, tuple) else out)
            pid_list.extend(pid.cpu().numpy())
            camid_list.extend(np.asarray(camid))
            batch = prefetcher.next()

        feats = torch.cat(chunks, dim=0)
        if self.cfg.TEST.NORM:
            feats = F.normalize(feats, p=2, dim=1)

        split = self.num_query
        qf, gf = feats[:split], feats[split:]
        q_pids, g_pids = np.asarray(pid_list[:split]), np.asarray(pid_list[split:])
        q_camids, g_camids = np.asarray(camid_list[:split]), np.asarray(camid_list[split:])

        # TODO: add re-ranking evaluation results
        distmat = -torch.mm(qf, gf.t()).cpu().numpy()
        cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
        self.logger.info(f"Test Results - Epoch: {self.current_epoch}")
        self.logger.info(f"mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            self.logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")

        rank1 = cmc[0]
        self.writer.add_scalar('rank1', rank1, self.global_step)
        self.writer.add_scalar('mAP', mAP, self.global_step)
        self.model.train()
        return {'rank1': rank1, 'mAP': mAP}
def inference_no_rerank(cfg, model, test_dataloader, num_query, top_k=None):
    """Evaluate ``model`` without re-ranking and collect top-k gallery indices.

    Features come from ``model(img)``, taking the first element when a
    tuple is returned. The first ``num_query`` samples are queries and the
    rest gallery. Samples are ranked by negated inner product and the
    metrics are logged. The function then returns the gallery positions that
    appear among the ``top_k`` nearest neighbours of at least one query.

    Args:
        cfg: Config; ``cfg.TEST.NORM`` toggles L2 normalisation and ``cfg``
            is forwarded to ``data_prefetcher``.
        model: Re-ID network.
        test_dataloader: Loader whose batches unpack as
            ``(img, pid, camid)``, query samples first.
        num_query: Number of leading samples that are queries.
        top_k: Neighbours per query to keep. ``None`` (default) keeps every
            gallery position. The original code referenced an undefined
            ``top_k`` and raised NameError here.

    Returns:
        np.ndarray: Sorted unique gallery indices (0-based within the
        gallery) found in some query's ``top_k`` nearest neighbours.
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    model.eval()

    feats, pids, camids = [], [], []
    test_prefetcher = data_prefetcher(test_dataloader, cfg)
    batch = test_prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        with torch.no_grad():
            feat = model(img)
        # Multi-output models: only the first returned tensor is evaluated.
        feats.append(feat[0] if isinstance(feat, tuple) else feat)
        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))

        batch = test_prefetcher.next()

    feats = torch.cat(feats, dim=0)
    if cfg.TEST.NORM:
        feats = F.normalize(feats, p=2, dim=1)

    # query
    qf = feats[:num_query]
    q_pids = np.asarray(pids[:num_query])
    q_camids = np.asarray(camids[:num_query])
    # gallery
    gf = feats[num_query:]
    g_pids = np.asarray(pids[num_query:])
    g_camids = np.asarray(camids[num_query:])

    # Negated inner product used as a distance (smaller = more similar).
    distmat = -torch.mm(qf, gf.t()).cpu().numpy()

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
    logger.info(f"Score: {(mAP + cmc[0]) / 2.:.1%}")

    # Per query, gallery indices from nearest to farthest.
    index = np.argsort(distmat, axis=1)
    # Slicing with top_k=None keeps every column.
    new_gallery_index = np.unique(index[:, :top_k].reshape(-1))

    print('new_gallery_index', len(new_gallery_index))

    return new_gallery_index
Exemple #7
0
def inference(cfg, model, test_dataloader, num_query):
    """Evaluate ``model`` by inner-product retrieval and log CMC / mAP.

    The first ``num_query`` samples from ``test_dataloader`` are queries and
    the rest gallery. When ``cfg.TEST.NORM`` is set the features are
    L2-normalised, so the inner product is cosine similarity. It is negated
    before being passed to ``evaluate`` so that smaller means closer.

    Args:
        cfg: Config; only ``cfg.TEST.NORM`` is read.
        model: Network whose output for an image batch is used directly as
            the feature tensor.
        test_dataloader: Loader whose batches unpack as
            ``(img, pid, camid)``.
        num_query: Number of leading samples that are queries.
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    model.eval()

    feature_batches, pid_list, camid_list = [], [], []
    prefetcher = data_prefetcher(test_dataloader)
    batch = prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        with torch.no_grad():
            feature_batches.append(model(img))
        pid_list.extend(pid.cpu().numpy())
        camid_list.extend(np.asarray(camid))
        batch = prefetcher.next()

    all_feats = torch.cat(feature_batches, dim=0)
    if cfg.TEST.NORM:
        all_feats = F.normalize(all_feats, p=2, dim=1)

    qf, gf = all_feats[:num_query], all_feats[num_query:]
    q_pids, g_pids = np.asarray(pid_list[:num_query]), np.asarray(pid_list[num_query:])
    q_camids, g_camids = np.asarray(camid_list[:num_query]), np.asarray(camid_list[num_query:])

    similarity = torch.mm(qf, gf.t()).cpu().numpy()
    cmc, mAP = evaluate(-similarity, q_pids, g_pids, q_camids, g_camids)
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
def inference(cfg, model, data_bunch, test_labels, num_query):
    """Evaluate a model on ``data_bunch.test_dl`` using cosine similarity.

    ``test_labels`` yields one ``(pid, camid)`` pair per test sample in
    loader order. The first ``num_query`` samples are queries and the rest
    gallery. Features are always L2-normalised here; ``cfg`` is not read.

    Args:
        cfg: Unused by this function.
        model: Network whose output for an image batch is the feature tensor.
        data_bunch: Object exposing ``test_dl`` that yields ``(imgs, _)``.
        test_labels: Iterable of ``(pid, camid)`` pairs.
        num_query: Number of leading samples that are queries.
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    # Single pass over test_labels, unpacking each entry into exactly 2 values.
    labels = [(pid, camid) for pid, camid in test_labels]
    pids = [pid for pid, _ in labels]
    camids = [camid for _, camid in labels]
    q_pids, g_pids = np.asarray(pids[:num_query]), np.asarray(pids[num_query:])
    q_camids, g_camids = np.asarray(camids[:num_query]), np.asarray(camids[num_query:])

    model.eval()
    outputs = []
    for imgs, _ in data_bunch.test_dl:
        with torch.no_grad():
            outputs.append(model(imgs))
    feats = torch.cat(outputs, dim=0)

    # Cosine similarity between L2-normalised query and gallery features.
    qf, gf = feats[:num_query], feats[num_query:]
    similarity = to_np(torch.mm(F.normalize(qf), F.normalize(gf).t()))

    # Negate the similarity so that smaller means closer for evaluate().
    cmc, mAP = evaluate(-similarity, q_pids, g_pids, q_camids, g_camids)
    logger.info('Compute CMC Curve')
    logger.info("mAP: {:.1%}".format(mAP))
    for r in [1, 5, 10]:
        logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
def inference_aligned_flipped(cfg, model, test_dataloader, num_query,
                              use_local_feature, use_rerank,
                              use_cross_feature):
    """Evaluate an aligned re-ID net with horizontal-flip test-time augmentation.

    ``model`` must return four tensors ``(gf, bn_gf, lf, bn_lf)``: global,
    BN global, local and BN local features. It is run on each batch and on
    the batch flipped along dim 3 (``torch.flip(img, [3])``). A single
    distance matrix is then built by ``compute_distmat`` from:

    * ``use_cross_feature=True``: ``bn_gf`` (+ ``lf``) with ``theta=0.45``;
    * otherwise: ``gf`` (+ ``bn_lf``) with ``theta=0.95``.

    Local features are collected only when ``use_local_feature`` is set.

    Args:
        cfg: Config forwarded to ``data_prefetcher`` and ``compute_distmat``.
        model: Network returning the four features listed above.
        test_dataloader: Loader whose batches unpack as
            ``(img, pid, camid)``, query samples first.
        num_query: Number of leading samples that are queries.
        use_local_feature: Also collect the local features and pass them on.
        use_rerank: Forwarded to ``compute_distmat``.
        use_cross_feature: Selects the feature pairing described above.
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing aligned with flipping")

    model.eval()

    pids, camids = [], []
    # Per-batch CPU features; the *_flipped lists hold the flipped pass.
    gfs, bn_gfs, lfs, bn_lfs = [], [], [], []
    gfs_flipped, bn_gfs_flipped, lfs_flipped, bn_lfs_flipped = [], [], [], []

    prefetcher = data_prefetcher(test_dataloader, cfg)
    batch = prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        with torch.no_grad():
            gf, bn_gf, lf, bn_lf = model(img)
            gff, bn_gff, lff, bn_lff = model(torch.flip(img, [3]))

        for store, tensor in ((gfs, gf), (bn_gfs, bn_gf),
                              (gfs_flipped, gff), (bn_gfs_flipped, bn_gff)):
            store.append(tensor.cpu())
        if use_local_feature:
            for store, tensor in ((lfs, lf), (bn_lfs, bn_lf),
                                  (lfs_flipped, lff), (bn_lfs_flipped, bn_lff)):
                store.append(tensor.cpu())

        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))
        batch = prefetcher.next()

    q_pids, g_pids = np.asarray(pids[:num_query]), np.asarray(pids[num_query:])
    q_camids, g_camids = np.asarray(camids[:num_query]), np.asarray(camids[num_query:])

    logger.info(
        f"use_cross_feature = {use_cross_feature}, use_local_feature = {use_local_feature}, use_rerank = {use_rerank}"
    )

    if use_cross_feature:
        logger.info("Computing distmat with bn_gf (+ lf)")
        global_pair = (bn_gfs, bn_gfs_flipped)
        local_pair = (lfs, lfs_flipped)
        theta = 0.45
    else:
        logger.info("Computing distmat with gf + bn_lf")
        global_pair = (gfs, gfs_flipped)
        local_pair = (bn_lfs, bn_lfs_flipped)
        theta = 0.95

    distmat = compute_distmat(cfg, num_query, *global_pair, *local_pair,
                              theta=theta,
                              use_local_feature=use_local_feature,
                              use_rerank=use_rerank)

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
    logger.info(f"Score: {(mAP + cmc[0]) / 2.:.1%}")
def _aligned_local_distmat(lqf, lgf):
    """Build the (query+gallery) x (query+gallery) aligned local distance matrix.

    Blocks are laid out as ``[[qq, qg], [qg.T, gg]]`` using
    ``low_memory_local_dist(..., aligned=True)`` on the CPU numpy copies of
    ``lqf`` (queries) and ``lgf`` (gallery).
    """
    lqf_np = lqf.cpu().numpy()
    lgf_np = lgf.cpu().numpy()
    local_qg_distmat = low_memory_local_dist(lqf_np, lgf_np, aligned=True)
    local_qq_distmat = low_memory_local_dist(lqf_np, lqf_np, aligned=True)
    local_gg_distmat = low_memory_local_dist(lgf_np, lgf_np, aligned=True)
    return np.concatenate([
        np.concatenate([local_qq_distmat, local_qg_distmat], axis=1),
        np.concatenate([local_qg_distmat.T, local_gg_distmat], axis=1),
    ], axis=0)


def inference_aligned(cfg, model, test_dataloader, num_query,
                      search_param=False, search_theta=True):
    """Evaluate an aligned re-ID model with re-ranking, optionally tuning it.

    ``model(img)`` returns either a feature tensor or a tuple whose first
    element is the global feature and second the local feature. Local
    features are fed to ``re_ranking`` through a full local distance matrix.
    The first ``num_query`` samples are queries and the rest gallery.

    Both searches pick hyper-parameters by the ``(rank1 + mAP) / 2`` score
    computed on this same test set, so the reported result is optimistic.

    Args:
        cfg: Config; ``cfg.TEST.NORM`` toggles L2 normalisation of the
            global features.
        model: Aligned re-ID network.
        test_dataloader: Loader whose batches unpack as
            ``(img, pid, camid)``, query samples first.
        num_query: Number of leading samples that are queries.
        search_param: Grid-search ``k1`` in 5..8, ``k2`` in 1..k1-1 and
            ``lambda_value`` in ``linspace(0, 0.5, 11)`` before evaluating.
        search_theta: When ``search_param`` is false, grid-search
            ``theta_value`` in ``linspace(0, 1.0, 11)`` with
            ``k1=6, k2=2, lambda_value=0.3``. If both flags are false,
            ``theta_value=0.9`` is used directly.

    Raises:
        RuntimeError: If a search produces no comparable score (e.g. NaN).
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    model.eval()

    feats, pids, camids = [], [], []
    local_feats = []
    test_prefetcher = data_prefetcher(test_dataloader)
    batch = test_prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        with torch.no_grad():
            feat = model(img)

        if isinstance(feat, tuple):
            feats.append(feat[0])
            local_feats.append(feat[1])
        else:
            feats.append(feat)
        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))

        batch = test_prefetcher.next()

    feats = torch.cat(feats, dim=0)
    if len(local_feats) > 0:
        local_feats = torch.cat(local_feats, dim=0)
    # Evaluated after concatenation: the concatenated tensor must be non-empty.
    has_local = len(local_feats) > 0
    if cfg.TEST.NORM:
        # Local features are 3-D and deliberately not normalised
        # (doing so had no effect on the results).
        feats = F.normalize(feats, p=2, dim=1)

    # query
    qf = feats[:num_query]
    q_pids = np.asarray(pids[:num_query])
    q_camids = np.asarray(camids[:num_query])
    # gallery
    gf = feats[num_query:]
    g_pids = np.asarray(pids[num_query:])
    g_camids = np.asarray(camids[num_query:])

    if has_local:
        # Swap the last two axes before computing the aligned local distance.
        local_distmat = _aligned_local_distmat(
            local_feats[:num_query].permute(0, 2, 1),
            local_feats[num_query:].permute(0, 2, 1))
    else:
        local_distmat = None

    # use reranking
    logger.info("use reranking")

    if search_param:
        # Start at -inf so the first finite score is always accepted.
        best_score = -np.inf
        best_param = None
        for k1 in range(5, 9):
            for k2 in range(1, k1):
                for lam in np.linspace(0, 0.5, 11):
                    # NOTE(review): the search omits local_distmat while the
                    # final call below uses it; confirm this is intended.
                    distmat = re_ranking(qf, gf, k1=k1, k2=k2, lambda_value=lam)
                    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids,
                                        g_camids)
                    score = (cmc[0] + mAP) / 2
                    print('k1, k2, l', k1, k2, np.around(lam, 2),
                          'r1, mAP, score', np.around(cmc[0], 4),
                          np.around(mAP, 4), np.around(score, 4))
                    if score > best_score:
                        best_score = score
                        best_param = [k1, k2, lam]
        if best_param is None:
            raise RuntimeError("re-ranking parameter search produced no valid score")
        print('Best Param', best_param)
        distmat = re_ranking(qf,
                             gf,
                             k1=best_param[0],
                             k2=best_param[1],
                             lambda_value=best_param[2],
                             local_distmat=local_distmat,
                             only_local=False)
    elif search_theta:
        best_score = -np.inf
        best_param = None
        for theta in np.linspace(0, 1.0, 11):
            distmat = re_ranking(qf,
                                 gf,
                                 k1=6,
                                 k2=2,
                                 lambda_value=0.3,
                                 local_distmat=local_distmat,
                                 theta_value=theta,
                                 only_local=False)  # (current best)
            cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
            score = (cmc[0] + mAP) / 2
            print('theta', theta, 'r1, mAP, score', np.around(cmc[0], 4),
                  np.around(mAP, 4), np.around(score, 4))
            if score > best_score:
                best_score = score
                best_param = theta
        if best_param is None:
            raise RuntimeError("re-ranking theta search produced no valid score")
        print('Best Param', best_param)
        distmat = re_ranking(qf,
                             gf,
                             k1=6,
                             k2=2,
                             lambda_value=0.3,
                             local_distmat=local_distmat,
                             theta_value=best_param,
                             only_local=False)  # (current best)
    else:
        distmat = re_ranking(qf,
                             gf,
                             k1=6,
                             k2=2,
                             lambda_value=0.3,
                             local_distmat=local_distmat,
                             only_local=False,
                             theta_value=0.9)  # (current best)

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
    logger.info(f"Score: {(mAP + cmc[0]) / 2.:.1%}")
def _unpack_flipped_outputs(ret, ret_flip, global_feat_dim=2048):
    """Split model outputs for a batch and its flipped copy into feature tuples.

    Returns ``((gf, bn_gf, lf, bn_lf), (gff, bn_gff, lff, bn_lff))``. The
    local entries are ``None`` when the model yields no local features.

    Raises:
        Exception: For an unsupported output structure.
    """
    # Check for a tensor first: len() of a tensor is its batch size, so the
    # length checks below would misfire on batches of size 2 or 4.
    if torch.is_tensor(ret):
        # One concatenated tensor: the first `global_feat_dim` columns are
        # read as bn_gf and the rest as gf, identically for both passes.
        # NOTE(review): layout taken from the original non-flipped branch
        # (its flipped branch had the halves swapped); confirm with the model.
        def _split(out):
            return out[:, global_feat_dim:], out[:, :global_feat_dim], None, None
        return _split(ret), _split(ret_flip)
    if len(ret) == 4:
        gf, bn_gf, lf, bn_lf = ret
        gff, bn_gff, lff, bn_lff = ret_flip
        return (gf, bn_gf, lf, bn_lf), (gff, bn_gff, lff, bn_lff)
    if len(ret) == 2:
        gf, bn_gf = ret
        gff, bn_gff = ret_flip
        return (gf, bn_gf, None, None), (gff, bn_gff, None, None)
    raise Exception(f"Unknown model returns, length = {len(ret)}")


def inference_aligned_flipped(cfg, model, test_dataloader, num_query,
                              use_local_feature, use_rerank,
                              use_cross_feature, global_feat_dim=2048):
    """Evaluate an aligned net with flip augmentation and sweep the fusion weight.

    The model is run on each batch and on its copy flipped along dim 3. Its
    output may be a 4-tuple ``(gf, bn_gf, lf, bn_lf)``, a 2-tuple
    ``(gf, bn_gf)`` or a single concatenated tensor (split at
    ``global_feat_dim``). Two distance matrices are built with
    ``compute_distmat``:

    * ``distmat2`` from ``bn_gf`` (+ ``lf``), ``theta=0.45``;
    * ``distmat1`` from ``gf`` (+ ``bn_lf``), ``theta=0.95``.

    The function then logs metrics for ``distmat1 * (1 - t) + distmat2 * t``
    for 21 values of ``t`` in ``[0, 1]``.

    Args:
        cfg: Config forwarded to ``data_prefetcher`` and ``compute_distmat``.
        model: Aligned re-ID network.
        test_dataloader: Loader whose batches unpack as
            ``(img, pid, camid)``, query samples first.
        num_query: Number of leading samples that are queries.
        use_local_feature: Collect local features: ``lf`` when
            ``use_cross_feature`` is set, otherwise ``bn_lf``.
        use_rerank: Forwarded to ``compute_distmat``.
        use_cross_feature: Selects which local feature is collected.
        global_feat_dim: Split point for single-tensor model outputs.

    Raises:
        ValueError: If ``use_local_feature`` is set but the model returns no
            local features.
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing aligned with flipping")

    model.eval()

    pids, camids = [], []
    gfs, bn_gfs, lfs, bn_lfs = [], [], [], []
    gfs_flipped, bn_gfs_flipped, lfs_flipped, bn_lfs_flipped = [], [], [], []

    test_prefetcher = data_prefetcher(test_dataloader, cfg)
    batch = test_prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        with torch.no_grad():
            ret = model(img)
            ret_flip = model(torch.flip(img, [3]))
        (gf, bn_gf, lf, bn_lf), (gff, bn_gff, lff, bn_lff) = \
            _unpack_flipped_outputs(ret, ret_flip, global_feat_dim)

        if use_local_feature and lf is None:
            raise ValueError(
                "use_local_feature=True but the model returned no local features")

        # 4 features
        gfs.append(gf.cpu())
        bn_gfs.append(bn_gf.cpu())

        if use_local_feature:
            if use_cross_feature:
                lfs.append(lf.cpu())
            else:
                bn_lfs.append(bn_lf.cpu())

        # 4 features flipped
        gfs_flipped.append(gff.cpu())
        bn_gfs_flipped.append(bn_gff.cpu())

        if use_local_feature:
            if use_cross_feature:
                lfs_flipped.append(lff.cpu())
            else:
                bn_lfs_flipped.append(bn_lff.cpu())

        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))

        batch = test_prefetcher.next()

    q_pids = np.asarray(pids[:num_query])
    q_camids = np.asarray(camids[:num_query])
    g_pids = np.asarray(pids[num_query:])
    g_camids = np.asarray(camids[num_query:])

    logger.info(
        f"use_local_feature = {use_local_feature}, use_rerank = {use_rerank}")

    # NOTE(review): with use_local_feature=True only one of lfs / bn_lfs is
    # filled, so one of the two calls below receives empty local lists.
    logger.info("Computing distmat with bn_gf")
    distmat2 = compute_distmat(cfg,
                               num_query,
                               bn_gfs,
                               bn_gfs_flipped,
                               lfs,
                               lfs_flipped,
                               theta=0.45,
                               use_local_feature=use_local_feature,
                               use_rerank=use_rerank)

    logger.info("Computing distmat with gf + bn_lf")
    distmat1 = compute_distmat(cfg,
                               num_query,
                               gfs,
                               gfs_flipped,
                               bn_lfs,
                               bn_lfs_flipped,
                               theta=0.95,
                               use_local_feature=use_local_feature,
                               use_rerank=use_rerank)

    # Sweep the fusion weight between the two distance matrices.
    for theta in np.linspace(0, 1, 21):
        distmat = distmat1 * (1 - theta) + distmat2 * theta

        cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
        logger.info(f"theta: {theta:.2%} mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
        logger.info(f"Score: {(mAP + cmc[0]) / 2.:.1%}")
Exemple #12
0
def _read_resized(path, size):
    """Load an image with OpenCV and resize it to ``size`` given as (width, height)."""
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.resize(img, size)


def _save_retrieval_examples(cfg, test_dataset_name, qf, gf, num_query,
                             items, top_k=10, query_rand=10, is_save_all=True):
    """Save one strip per query: the query image followed by its top-k gallery hits.

    ``items`` is the dataset's sample list with queries first. Each entry is
    indexed as ``item[0]`` = image path, ``item[1]`` = pid and
    ``item[2]`` = camid. Frames are drawn in OpenCV BGR colours:

    * the query is outlined (255, 0, 0);
    * a gallery image with the same pid and camid gets (255, 255, 0);
    * the same pid but a different camid gets (0, 255, 0);
    * any other pid gets (0, 0, 255).

    Files are written under
    ``<OUTPUT_DIR>/<'-'.join(TEST_NAMES)>/examples/<test_dataset_name>/``.
    """
    query_idx = (list(range(0, num_query)) if is_save_all
                 else random.sample(range(0, num_query), query_rand))
    print(f'|-------- Randomly saving top-{top_k} results of {len(query_idx)} queries for {test_dataset_name} --------|')

    q_items = items[:num_query]
    g_items = items[num_query:]

    # Negated inner product: ascending argsort puts the most similar first.
    distmat = -torch.mm(qf[query_idx], gf.t()).cpu().numpy()
    indices = np.argsort(distmat, axis=1)

    # (width, height) for cv2.resize.
    save_img_size = (128, 256) if test_dataset_name == 'market1501' else (256, 256)
    picsavedir = os.path.join(cfg.OUTPUT_DIR, '-'.join(cfg.DATASETS.TEST_NAMES),
                              'examples', test_dataset_name)
    os.makedirs(picsavedir, exist_ok=True)

    for row, q_idx in enumerate(query_idx):
        q_path, q_pid, q_camid = q_items[q_idx][0], q_items[q_idx][1], q_items[q_idx][2]
        order = indices[row]

        print('Query Path : ', q_path)
        print('Result idx : ', order[:top_k])

        q_img = _read_resized(q_path, save_img_size)
        cv2.rectangle(q_img, (0, 0), save_img_size, (255, 0, 0), 4)
        img_list = [q_img]
        savefilename = 'q-' + q_path.split('/')[-1] + '_g'

        # Same-pid/same-camid gallery hits are kept and highlighted.
        for g_idx in order[:top_k]:
            g_path, g_pid, g_camid = g_items[g_idx][0], g_items[g_idx][1], g_items[g_idx][2]
            g_img = _read_resized(g_path, save_img_size)
            if q_pid == g_pid:
                colour = (255, 255, 0) if q_camid == g_camid else (0, 255, 0)
            else:
                colour = (0, 0, 255)
            cv2.rectangle(g_img, (0, 0), save_img_size, colour, 4)
            img_list.append(g_img)
            savefilename += '-' + str(g_pid)

        pic = np.concatenate(img_list, 1)
        savefilepath = os.path.join(picsavedir, savefilename + '.jpg')
        cv2.imwrite(savefilepath, pic)
        print('Save example picture to ', savefilepath)


def inference(
        cfg,
        model,
        test_dataloader_collection,
        num_query_collection,
        is_vis=False,
        test_collection=None,
        use_mask=True,
        num_parts=10,
        mask_image=False
):
    """Evaluate a mask/graph re-ID model on each test dataset.

    Datasets are taken in lockstep from ``cfg.DATASETS.TEST_NAMES``,
    ``test_dataloader_collection`` and ``num_query_collection``. The model is
    called as ``model(img, img, mask, adj_batch)``, and a sample's feature
    is the concatenation of ``output[1]`` and ``output[3]``. The first
    ``num_query`` samples are queries and the rest gallery. Ranking uses the
    negated inner product (negated cosine similarity when ``cfg.TEST.NORM``
    is set). Optionally, query/top-10 example strips are saved for every
    query.

    Args:
        cfg: Config (``DATASETS.TEST_NAMES``, ``TEST.NORM``, ``OUTPUT_DIR``).
        model: Network taking ``(img, img, mask, adj)``.
        test_dataloader_collection: One loader per test dataset.
        num_query_collection: Number of query samples per dataset.
        is_vis: Save retrieval example images.
        test_collection: Per-dataset sample lists; required when ``is_vis``.
        use_mask: Loaders yield ``(img, mask, pid, camid)`` via
            ``data_prefetcher_mask``; otherwise ``data_prefetcher`` is used.
        num_parts: Accepted for interface compatibility; unused here.
        mask_image: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If ``is_vis`` is set without ``test_collection``.
    """
    if is_vis and test_collection is None:
        raise ValueError("is_vis=True requires test_collection")

    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    model.eval()

    adj = torch.from_numpy(coarse_adj_npy).float()
    datasets = zip(cfg.DATASETS.TEST_NAMES, test_dataloader_collection,
                   num_query_collection)
    for idx, (test_dataset_name, test_dataloader, num_query) in enumerate(datasets):
        feats, pids, camids = [], [], []
        if use_mask:
            test_prefetcher = data_prefetcher_mask(test_dataloader)
        else:
            test_prefetcher = data_prefetcher(test_dataloader)
        batch = test_prefetcher.next()
        while batch[0] is not None:
            if use_mask:
                img, mask, pid, camid = batch
            else:
                # Without this branch img/mask were undefined on this path.
                # NOTE(review): assumes data_prefetcher yields
                # (img, pid, camid) and the model accepts mask=None.
                img, pid, camid = batch
                mask = None
            # Keep the adjacency on the images' device (no-op if it already is).
            adj_batch = adj.to(img.device).repeat(img.size(0), 1, 1)

            with torch.no_grad():
                output = model(img, img, mask, adj_batch)
                feat = torch.cat([output[1], output[3]], dim=1)

            feats.append(feat)
            pids.extend(pid.cpu().numpy())
            camids.extend(np.asarray(camid))

            batch = test_prefetcher.next()

        feats = torch.cat(feats, dim=0)
        if cfg.TEST.NORM:
            feats = F.normalize(feats, p=2, dim=1)
        # query
        qf = feats[:num_query]
        q_pids = np.asarray(pids[:num_query])
        q_camids = np.asarray(camids[:num_query])
        # gallery
        gf = feats[num_query:]
        g_pids = np.asarray(pids[num_query:])
        g_camids = np.asarray(camids[num_query:])

        # Inner-product similarity, negated for evaluate (smaller = closer).
        distmat = torch.mm(qf, gf.t()).cpu().numpy()
        cmc, mAP = evaluate(-distmat, q_pids, g_pids, q_camids, g_camids)
        logger.info(f"Results on {test_dataset_name} : ")
        logger.info(f"mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")

        if is_vis:
            _save_retrieval_examples(cfg, test_dataset_name, qf, gf,
                                     num_query, test_collection[idx])
Exemple #13
0
def inference(cfg, model, train_dataloader, test_dataloader, num_query):
    """Evaluate ``model`` on the test set with an inner-product distance.

    Features are extracted batch by batch and split into two parts. The first
    ``num_query`` samples are the queries and the rest are the gallery.
    Samples are ranked by ``1 - q @ g.T``.  mAP and CMC Rank-1/5/10 are
    logged twice. The first time is for the model's second output (``feat``).
    The second time, under a "cat feature" header, is for its first output
    (``cat_feat``).

    Args:
        cfg: config node; only ``cfg.TEST.NORM`` is read.
        model: network whose forward returns a ``(cat_feat, feat)`` pair.
        train_dataloader: only needed by the precise-BN update, which is
            currently disabled; kept so callers need not change.
        test_dataloader: loader consumed through ``data_prefetcher``; each
            batch unpacks to ``(img, pid, camid)``.
        num_query: number of query samples at the start of the test set.
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    logger.info("compute precise batchnorm ...")
    # Precise-BN recomputation is currently disabled; it would be:
    # model.train()
    # update_bn_stats(model, train_dataloader, num_iters=300)
    model.eval()

    cat_feats, feats, pids, camids = [], [], [], []
    test_prefetcher = data_prefetcher(test_dataloader)
    batch = test_prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        # Without no_grad, the outputs would carry an autograd graph whenever
        # the model's parameters require grad. Memory would grow with every
        # batch, and the ``.numpy()`` calls below would raise on tensors that
        # require grad.
        with torch.no_grad():
            cat_feat, feat = model(img)
        cat_feats.append(cat_feat.cpu())
        feats.append(feat.cpu())
        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))

        batch = test_prefetcher.next()

    feats = torch.cat(feats, dim=0)
    cat_feats = torch.cat(cat_feats, dim=0)
    if cfg.TEST.NORM:
        feats = F.normalize(feats, p=2, dim=1)
        cat_feats = F.normalize(cat_feats, p=2, dim=1)

    # query: first num_query samples; gallery: the remainder
    q_pids = np.asarray(pids[:num_query])
    q_camids = np.asarray(camids[:num_query])
    g_pids = np.asarray(pids[num_query:])
    g_camids = np.asarray(camids[num_query:])

    # Query x gallery inner products. These are cosine similarities only when
    # cfg.TEST.NORM is set; otherwise they are plain dot products.
    distmat = torch.mm(feats[:num_query], feats[num_query:].t())
    cat_dist = torch.mm(cat_feats[:num_query], cat_feats[num_query:].t())

    def _evaluate_and_log(sim):
        # Turn the similarity into a distance (smaller = closer) for evaluate().
        cmc, mAP = evaluate(1 - sim.numpy(), q_pids, g_pids, q_camids,
                            g_camids)
        logger.info(f"mAP: {mAP:.1%}")
        for r in [1, 5, 10]:
            logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")

    _evaluate_and_log(distmat)
    logger.info('cat feature')
    _evaluate_and_log(cat_dist)
Exemple #14
0
def inference_aligned(
    cfg,
    model,
    test_dataloader,
    num_query,
):
    """Evaluate ``model`` with re-ranking over global and aligned local features.

    The model returns ``(g_feat, l_feat)`` per batch. Global features are
    L2-normalised when ``cfg.TEST.NORM`` is set. Local features are used as-is
    to build an aligned local distance matrix over all query+gallery samples.
    That matrix is passed to ``re_ranking`` together with the global features.
    mAP, CMC Rank-1/5/10 and the score ``(mAP + Rank-1) / 2`` are logged.

    Args:
        cfg: config node; only ``cfg.TEST.NORM`` is read.
        model: network returning ``(global_feat, local_feat)``.
        test_dataloader: loader consumed through ``data_prefetcher``; each
            batch unpacks to ``(img, pid, camid)``.
        num_query: number of query samples at the start of the test set.
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    model.eval()

    g_feats, l_feats, pids, camids = [], [], [], []
    val_prefetcher = data_prefetcher(test_dataloader)
    batch = val_prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        with torch.no_grad():
            g_feat, l_feat = model(img)
            # The author marked a flipped pass,
            # model(torch.flip(img, [3])), as "better".
        g_feats.append(g_feat.cpu())
        l_feats.append(l_feat.cpu())
        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))

        batch = val_prefetcher.next()

    g_feats = torch.cat(g_feats, dim=0)
    l_feats = torch.cat(l_feats, dim=0)

    # Only global features are normalised; local ones stay raw.
    if cfg.TEST.NORM:
        g_feats = F.normalize(g_feats, p=2, dim=1)

    # query: first num_query samples; gallery: the remainder
    qf = g_feats[:num_query]
    gf = g_feats[num_query:]
    q_pids = np.asarray(pids[:num_query])
    q_camids = np.asarray(camids[:num_query])
    g_pids = np.asarray(pids[num_query:])
    g_camids = np.asarray(camids[num_query:])

    logger.info("--------use re-ranking--------")
    # Swap dims 1 and 2 of the local features and convert to numpy once,
    # rather than once per distance call.
    lqf = l_feats[:num_query].permute(0, 2, 1).numpy()
    lgf = l_feats[num_query:].permute(0, 2, 1).numpy()
    local_qg = low_memory_local_dist(lqf, lgf, aligned=True)
    local_qq = low_memory_local_dist(lqf, lqf, aligned=True)
    local_gg = low_memory_local_dist(lgf, lgf, aligned=True)
    # Square (q+g) x (q+g) matrix laid out as [[qq, qg], [qg^T, gg]].
    local_dist = np.block([[local_qq, local_qg], [local_qg.T, local_gg]])

    # Fixed hyper-parameters. A sweep over theta in [0, 1] (step 0.1) was
    # previously kept here commented out; it scored each theta by
    # (mAP + Rank-1) / 2.
    distmat = re_ranking(qf,
                         gf,
                         k1=6,
                         k2=2,
                         lambda_value=0.3,
                         local_distmat=local_dist,
                         theta_value=0.5,
                         only_local=False)

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
    logger.info(f"Score: {(mAP + cmc[0]) / 2.:.1%}")
# Report the speed-up of the Cython evaluator over the pure-Python one.
# NOTE(review): pytime and cytime are not defined in this chunk; presumably
# they are timings measured earlier in the same script -- confirm.
print('Cython is {} times faster than python\n'.format(pytime / cytime))

print("=> Check precision")

# Random toy problem: num_q queries, num_g gallery items, camera ids in [0, 5)
# and distances uniform in [0, 20).  Presumably the point is to compare the
# two backends' outputs printed below, not to get meaningful scores.
num_q = 30
num_g = 300
max_rank = 5
distmat = np.random.rand(num_q, num_g) * 20
q_pids = np.random.randint(0, num_q, size=num_q)
g_pids = np.random.randint(0, num_g, size=num_g)
q_camids = np.random.randint(0, 5, size=num_q)
g_camids = np.random.randint(0, 5, size=num_g)

# Pure-Python backend.  As unpacked here, this call yields four values
# (cmc, t_cmc, mAP, t_mAP) -- NOTE(review): confirm against evaluate(),
# which is defined outside this chunk.
cmc, t_cmc, mAP, t_mAP = evaluate(distmat,
                                  q_pids,
                                  g_pids,
                                  q_camids,
                                  g_camids,
                                  max_rank,
                                  use_cython=False)
print("Python:\nmAP = {} \ncmc = {} \nt_cmc={} \nt_mAP={}\n".format(
    mAP, cmc, t_cmc, t_mAP))
# Cython backend on the same inputs; unpacked as (cmc, mAP) only.
cmc, mAP = evaluate(distmat,
                    q_pids,
                    g_pids,
                    q_camids,
                    g_camids,
                    max_rank,
                    use_cython=True)
print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
Exemple #16
0
def inference_flipped(cfg,
                      model,
                      test_dataloader,
                      num_query,
                      use_re_ranking=True,
                      distance_metric='global_local'):
    """Evaluate ``model`` on original and horizontally flipped test images.

    The model returns ``(global_feat, local_feat)`` for each pass.

    * ``use_re_ranking=False``: ranks with a distance chosen by
      ``distance_metric``. ``'global'`` is the squared Euclidean distance on
      global features. ``'local'`` is the aligned local distance.
      ``'global_local'`` is their sum. Only the non-flipped features are
      used in this mode.
    * ``use_re_ranking=True``: runs ``re_ranking`` separately on the original
      and flipped features, each with its own full local distance matrix.
      The two results are averaged.

    Logs mAP, CMC Rank-1/5/10 and ``(mAP + Rank-1) / 2``.

    Raises:
        ValueError: if ``use_re_ranking`` is False and ``distance_metric`` is
            not ``'global'``, ``'local'`` or ``'global_local'``.  Checked
            before any inference so a typo does not waste a full pass.
    """
    valid_metrics = ('global', 'local', 'global_local')
    if not use_re_ranking and distance_metric not in valid_metrics:
        raise ValueError(f"distance_metric must be one of {valid_metrics}, "
                         f"got {distance_metric!r}")

    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    model.eval()

    g_feats, l_feats, gf_feats, lf_feats, pids, camids = [], [], [], [], [], []
    val_prefetcher = data_prefetcher(test_dataloader)
    batch = val_prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch

        with torch.no_grad():
            g_feat, l_feat = model(img)
            # Second pass on the input flipped along dim 3 (presumably width).
            gf_feat, lf_feat = model(torch.flip(img, [3]))

        g_feats.append(g_feat.cpu())
        l_feats.append(l_feat.cpu())
        gf_feats.append(gf_feat.cpu())
        lf_feats.append(lf_feat.cpu())

        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))

        batch = val_prefetcher.next()

    g_feats = torch.cat(g_feats, dim=0)
    l_feats = torch.cat(l_feats, dim=0)
    gf_feats = torch.cat(gf_feats, dim=0)
    lf_feats = torch.cat(lf_feats, dim=0)

    # Only global features are normalised; local ones stay raw.
    if cfg.TEST.NORM:
        g_feats = F.normalize(g_feats, p=2, dim=1)
        gf_feats = F.normalize(gf_feats, p=2, dim=1)

    # query: first num_query samples; gallery: the remainder
    qf, gf = g_feats[:num_query], g_feats[num_query:]
    lqf, lgf = l_feats[:num_query], l_feats[num_query:]
    qff, gff = gf_feats[:num_query], gf_feats[num_query:]
    lqff, lgff = lf_feats[:num_query], lf_feats[num_query:]
    q_pids = np.asarray(pids[:num_query])
    q_camids = np.asarray(camids[:num_query])
    g_pids = np.asarray(pids[num_query:])
    g_camids = np.asarray(camids[num_query:])

    def _local_dist(a, b):
        # Aligned local distance after swapping dims 1 and 2 of each input.
        return low_memory_local_dist(a.permute(0, 2, 1).numpy(),
                                     b.permute(0, 2, 1).numpy(),
                                     aligned=True)

    def _full_local_dist(lq, lg):
        # (q+g) x (q+g) local distances laid out as [[qq, qg], [qg^T, gg]].
        qg = _local_dist(lq, lg)
        return np.block([[_local_dist(lq, lq), qg],
                         [qg.T, _local_dist(lg, lg)]])

    def _global_dist(q, g):
        # Squared Euclidean distance: ||q||^2 + ||g||^2 - 2 q.g
        m, n = q.shape[0], g.shape[0]
        dist = torch.pow(q, 2).sum(dim=1, keepdim=True).expand(m, n) + \
               torch.pow(g, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        dist.addmm_(q, g.t(), beta=1, alpha=-2)
        return dist.numpy()

    if not use_re_ranking:
        # Only compute the distance(s) the chosen metric actually needs.
        if distance_metric == 'global':
            logger.info("--------use global features--------")
            distmat = _global_dist(qf, gf)
        elif distance_metric == 'local':
            logger.info("--------use local features--------")
            distmat = _local_dist(lqf, lgf)
        else:  # 'global_local' (validated at the top)
            logger.info("--------use global and local features--------")
            distmat = _global_dist(qf, gf) + _local_dist(lqf, lgf)
    else:
        logger.info("--------use re-ranking--------")
        local_dist = _full_local_dist(lqf, lgf)

        logger.info("--------use re-ranking flipped--------")
        local_dist_flipped = _full_local_dist(lqff, lgff)

        # theta was previously chosen with a commented-out sweep over
        # [0, 1) in steps of 0.05, scored by (mAP + Rank-1) / 2.
        theta = 0.45
        distmat = re_ranking(qf,
                             gf,
                             k1=6,
                             k2=2,
                             lambda_value=0.3,
                             local_distmat=local_dist,
                             theta_value=theta,
                             only_local=False)
        distmat_flip = re_ranking(qff,
                                  gff,
                                  k1=6,
                                  k2=2,
                                  lambda_value=0.3,
                                  local_distmat=local_dist_flipped,
                                  theta_value=theta,
                                  only_local=False)

        # Average the original and flipped re-ranked distances.
        distmat = (distmat + distmat_flip) / 2

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
    logger.info(f"Score: {(mAP + cmc[0]) / 2.:.1%}")
def inference(cfg, model, test_dataloader, num_query, search_param=False):
    """Evaluate ``model`` with k-reciprocal re-ranking on global features.

    When the model returns a tuple, only its first element is used as the
    feature. Features are L2-normalised when ``cfg.TEST.NORM`` is set. They
    are then split into two parts: the first ``num_query`` samples are the
    queries and the rest are the gallery. Logs mAP, CMC Rank-1/5/10 and
    ``(mAP + Rank-1) / 2``.

    Args:
        cfg: config node; ``cfg.TEST.NORM`` is read and ``cfg`` is also
            passed to ``data_prefetcher``.
        model: network returning a feature tensor or a tuple whose first
            element is the feature tensor.
        test_dataloader: loader whose prefetched batches unpack to
            ``(img, pid, camid)``.
        num_query: number of query samples at the start of the test set.
        search_param: if True, grid-search the re-ranking parameters on this
            same split before the final evaluation. The searched parameters
            are k1 in 5..8, k2 in 1..k1-1 and lambda in linspace(0, 0.5, 11).
            This is slow and tunes on the evaluation data. When False (the
            default), fixed values k1=6, k2=2, lambda=0.3, theta=0.9 are used.
    """
    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    model.eval()

    feats, pids, camids = [], [], []
    test_prefetcher = data_prefetcher(test_dataloader, cfg)
    batch = test_prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        with torch.no_grad():
            feat = model(img)

        # Tuple outputs carry extra (e.g. local) features; keep the first.
        feats.append(feat[0] if isinstance(feat, tuple) else feat)
        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))

        batch = test_prefetcher.next()

    feats = torch.cat(feats, dim=0)
    if cfg.TEST.NORM:
        feats = F.normalize(feats, p=2, dim=1)

    # query: first num_query samples; gallery: the remainder
    qf = feats[:num_query]
    q_pids = np.asarray(pids[:num_query])
    q_camids = np.asarray(camids[:num_query])
    gf = feats[num_query:]
    g_pids = np.asarray(pids[num_query:])
    g_camids = np.asarray(camids[num_query:])

    local_distmat = None  # only global features are used here
    logger.info("use reranking")

    if search_param:
        # Start below any achievable score so best_param is always set by the
        # first candidate (a best_score of 0 could leave it empty).
        best_score = -np.inf
        best_param = None
        for k1 in range(5, 9):
            for k2 in range(1, k1):
                for lam in np.linspace(0, 0.5, 11):
                    distmat = re_ranking(qf, gf, k1=k1, k2=k2,
                                         lambda_value=lam)
                    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids,
                                        g_camids)
                    score = (cmc[0] + mAP) / 2
                    print('k1, k2, l', k1, k2, np.around(lam, 2),
                          'r1, mAP, score', np.around(cmc[0], 4),
                          np.around(mAP, 4), np.around(score, 4))
                    if score > best_score:
                        best_score = score
                        best_param = [k1, k2, lam]
        print('Best Param', best_param)
        distmat = re_ranking(qf,
                             gf,
                             k1=best_param[0],
                             k2=best_param[1],
                             lambda_value=best_param[2],
                             local_distmat=local_distmat,
                             only_local=False)
    else:
        # Fixed parameters, marked "current best" in the original code.
        distmat = re_ranking(qf,
                             gf,
                             k1=6,
                             k2=2,
                             lambda_value=0.3,
                             local_distmat=local_distmat,
                             only_local=False,
                             theta_value=0.9)

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
    logger.info(f"Score: {(mAP + cmc[0]) / 2.:.1%}")
Exemple #18
0
def inference_flipped(cfg, model, test_dataloader, num_query,
                      search_theta=True):
    """Evaluate ``model`` with flip augmentation and re-ranking.

    Each batch goes through the model twice: once as-is and once flipped
    along dim 3. A tuple output is ``(global, local)``; otherwise there is
    only a global feature and no local distance is used. ``re_ranking`` runs
    on the original and the flipped features, and the two distance matrices
    are averaged.

    Args:
        cfg: config node; ``cfg.TEST.NORM`` is read. ``cfg.MODEL.NAME`` is
            used in the saved file name.
        model: network returning a feature tensor or a ``(global, local)``
            tuple.
        test_dataloader: loader consumed through ``data_prefetcher``; each
            batch unpacks to ``(img, pid, camid)``.
        num_query: number of query samples at the start of the test set.
        search_theta: if True (the default), sweep theta over
            ``linspace(0, 1, 21)`` on this split. Each theta is scored by
            ``(Rank-1 + mAP) / 2`` and the best distance matrix is kept. The
            theta=0.95 matrix is also written to
            ``dist_mats/val_<MODEL.NAME>_<timestamp>_t0.95_flip.h5``. If
            False, theta=0.95 is used directly.
    """
    import os  # local import: needed only to create the output directory

    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Start inferencing")

    model.eval()

    feats, feats_flipped, pids, camids = [], [], [], []
    local_feats, local_feats_flipped = [], []
    test_prefetcher = data_prefetcher(test_dataloader)
    batch = test_prefetcher.next()
    while batch[0] is not None:
        img, pid, camid = batch
        with torch.no_grad():
            feat = model(img)
            feat_flipped = model(torch.flip(img, [3]))

        if isinstance(feat, tuple):
            feats.append(feat[0])
            local_feats.append(feat[1])
            feats_flipped.append(feat_flipped[0])
            local_feats_flipped.append(feat_flipped[1])
        else:
            feats.append(feat)
            feats_flipped.append(feat_flipped)

        pids.extend(pid.cpu().numpy())
        camids.extend(np.asarray(camid))

        batch = test_prefetcher.next()

    feats = torch.cat(feats, dim=0)
    feats_flipped = torch.cat(feats_flipped, dim=0)

    # Only global features are normalised; local ones stay raw.
    if cfg.TEST.NORM:
        feats = F.normalize(feats, p=2, dim=1)
        feats_flipped = F.normalize(feats_flipped, p=2, dim=1)

    # query: first num_query samples; gallery: the remainder
    qf = feats[:num_query]
    qf_flipped = feats_flipped[:num_query]
    gf = feats[num_query:]
    gf_flipped = feats_flipped[num_query:]
    q_pids = np.asarray(pids[:num_query])
    q_camids = np.asarray(camids[:num_query])
    g_pids = np.asarray(pids[num_query:])
    g_camids = np.asarray(camids[num_query:])

    def _full_local_dist(lq, lg):
        # (q+g) x (q+g) aligned local distances laid out as
        # [[qq, qg], [qg^T, gg]], after swapping dims 1 and 2.
        lq = lq.permute(0, 2, 1).cpu().numpy()
        lg = lg.permute(0, 2, 1).cpu().numpy()
        qg = low_memory_local_dist(lq, lg, aligned=True)
        qq = low_memory_local_dist(lq, lq, aligned=True)
        gg = low_memory_local_dist(lg, lg, aligned=True)
        return np.block([[qq, qg], [qg.T, gg]])

    if local_feats:
        local_feats = torch.cat(local_feats, dim=0)
        local_feats_flipped = torch.cat(local_feats_flipped, dim=0)
        local_distmat = _full_local_dist(local_feats[:num_query],
                                         local_feats[num_query:])
        local_distmat_flipped = _full_local_dist(
            local_feats_flipped[:num_query], local_feats_flipped[num_query:])
    else:
        # Both must be bound: they are passed to re_ranking below. Previously
        # only local_distmat was set here, which raised NameError for models
        # without local features.
        local_distmat = None
        local_distmat_flipped = None

    logger.info("use reranking")

    def _fused_distmat(theta):
        # Re-rank original and flipped features with the same theta and
        # average the two distance matrices.
        distmat = re_ranking(qf,
                             gf,
                             k1=6,
                             k2=2,
                             lambda_value=0.3,
                             local_distmat=local_distmat,
                             theta_value=theta,
                             only_local=False)
        distmat_flipped = re_ranking(qf_flipped,
                                     gf_flipped,
                                     k1=6,
                                     k2=2,
                                     lambda_value=0.3,
                                     local_distmat=local_distmat_flipped,
                                     theta_value=theta,
                                     only_local=False)
        return (distmat + distmat_flipped) / 2

    if search_theta:
        # Start below any achievable score so best_* are always bound.
        best_score = -np.inf
        best_param = None
        best_distmat = None
        for theta in np.linspace(0, 1.0, 21):
            distmat = _fused_distmat(theta)
            cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
            score = (cmc[0] + mAP) / 2
            print('theta', np.around(theta, 2), 'r1, mAP, score',
                  np.around(cmc[0], 4), np.around(mAP, 4), np.around(score, 4))
            if score > best_score:
                best_score = score
                best_param = theta
                best_distmat = distmat

            # Persist the theta=0.95 matrix (linspace gives it only up to
            # float error, hence the tolerance).
            if abs(theta - 0.95) < 1e-4:
                strtime = time.strftime("%Y%m%d_%H%M%S", time.localtime())
                os.makedirs('dist_mats', exist_ok=True)
                with h5py.File(
                        'dist_mats/val_%s_%s_t0.95_flip.h5' %
                    (cfg.MODEL.NAME, strtime), 'w') as f:
                    f.create_dataset('dist_mat', data=distmat,
                                     compression='gzip')

        print('Best Param', best_param)
        distmat = best_distmat
    else:
        distmat = _fused_distmat(0.95)  # "current best" per the original code

    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
    logger.info(f"mAP: {mAP:.1%}")
    for r in [1, 5, 10]:
        logger.info(f"CMC curve, Rank-{r:<3}:{cmc[r - 1]:.1%}")
    logger.info(f"Score: {(mAP + cmc[0]) / 2.:.1%}")