Example #1
def mgn_euclidean_dist(x, y, norm=True, weights=(1, 1, 1, 1, 1, 1, 1, 1)):
    """Weighted sum of per-part Euclidean distances over eight 256-dim MGN feature chunks."""
    distmat = None

    for i in range(8):
        if norm:
            dist = euclidean_dist(
                F.normalize(x[:, i * 256:(i + 1) * 256], p=2, dim=1),
                F.normalize(y[:, i * 256:(i + 1) * 256], p=2, dim=1))
        else:
            dist = euclidean_dist(x[:, i * 256:(i + 1) * 256],
                                  y[:, i * 256:(i + 1) * 256])
        if distmat is None:
            distmat = dist * weights[i]
        else:
            distmat += dist * weights[i]
    return distmat
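Every example in this listing calls a project-local euclidean_dist(x, y) helper that is not reproduced here. The sketch below shows the conventional pairwise implementation it is assumed to match, plus a toy call of mgn_euclidean_dist on 8 x 256-dim MGN part features; the helper body, shapes, and weights are illustrative assumptions, not the project's exact code.

import torch
import torch.nn.functional as F  # needed by mgn_euclidean_dist above


def euclidean_dist(x, y):
    # Assumed pairwise Euclidean distance between rows of x (m, d) and y (n, d).
    xx = x.pow(2).sum(dim=1, keepdim=True).expand(x.size(0), y.size(0))
    yy = y.pow(2).sum(dim=1, keepdim=True).expand(y.size(0), x.size(0)).t()
    dist = xx + yy - 2 * x.matmul(y.t())
    return dist.clamp(min=1e-12).sqrt()


# Toy inputs: 4 query and 6 gallery features, each 8 parts x 256 dims, using the
# part weighting that the inference code in later examples applies.
x = torch.randn(4, 8 * 256)
y = torch.randn(6, 8 * 256)
weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
distmat = mgn_euclidean_dist(x, y, norm=True, weights=weights)  # shape (4, 6)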
Example #2
def get_sparse_distmat(all_feature,
                       eps,
                       len_slice=1000,
                       use_gpu=False,
                       dist_k=-1,
                       top_k=35):
    if use_gpu:
        gpu_feature = all_feature.cuda()
    else:
        gpu_feature = all_feature
    n_iter = len(all_feature) // len_slice + int(
        len(all_feature) % len_slice > 0)
    distmats = []
    kdist = []
    with tqdm(total=n_iter) as pbar:
        for i in range(n_iter):
            if use_gpu:
                distmat = euclidean_dist(
                    gpu_feature[i * len_slice:(i + 1) * len_slice],
                    gpu_feature).data.cpu().numpy()
            else:
                distmat = euclidean_dist(
                    gpu_feature[i * len_slice:(i + 1) * len_slice],
                    gpu_feature).numpy()

            if dist_k > 0:
                dist_rank = np.argpartition(distmat, range(1,
                                                           dist_k + 1))  # 1,N
                for j in range(distmat.shape[0]):
                    kdist.append(distmat[j, dist_rank[j, dist_k]])
            if 0:  # disabled alternative: keep only the top_k smallest distances per row
                initial_rank = np.argpartition(distmat, top_k)  # 1,N
                for j in range(distmat.shape[0]):
                    distmat[j, initial_rank[j, top_k:]] = 0
            else:
                # keep only distances within eps; larger entries become implicit zeros in the CSR matrix
                distmat[distmat > eps] = 0
            distmats.append(sparse.csr_matrix(distmat))

            pbar.update(1)
    if dist_k > 0:
        return sparse.vstack(distmats), kdist

    return sparse.vstack(distmats)
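A hedged usage sketch of get_sparse_distmat (assuming the function above is importable together with its euclidean_dist, numpy, scipy.sparse, and tqdm dependencies; feature counts and eps are illustrative):

import torch
import torch.nn.functional as F

# Hypothetical input: 3000 L2-normalised 256-dim features, processed in slices of 1000.
feats = F.normalize(torch.randn(3000, 256), p=2, dim=1)

# Distances above eps are zeroed before the CSR conversion, so the result stays sparse.
dist_sparse = get_sparse_distmat(feats, eps=0.6, len_slice=1000, use_gpu=False)
print(dist_sparse.shape, dist_sparse.nnz)

# With dist_k > 0 the k-th nearest distance of each row is returned as well; the
# pseudo-labelling code in a later example plots these values to pick a clustering eps.
dist_sparse, kdist = get_sparse_distmat(feats, eps=0.6, dist_k=10, use_gpu=False)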
Example #3
    def evaluate(self):
        self.model.eval()
        num_query = self.num_query
        feats, pids, camids = [], [], []
        with torch.no_grad():
            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, pid, camid, _ = batch
                data = data.cuda()

                # ff = torch.FloatTensor(data.size(0), 2048).zero_()
                # for i in range(2):
                #     if i == 1:
                #         data = data.index_select(3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                #     outputs = self.model(data)
                #     f = outputs.data.cpu()
                #     ff = ff + f

                ff = self.model(data).data.cpu()
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))

                feats.append(ff)
                pids.append(pid)
                camids.append(camid)
        feats = torch.cat(feats, dim=0)
        pids = torch.cat(pids, dim=0)
        camids = torch.cat(camids, dim=0)

        query_feat = feats[:num_query]
        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_feat = feats[num_query:]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = euclidean_dist(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(
            distmat.numpy(),
            query_pid.numpy(),
            gallery_pid.numpy(),
            query_camid.numpy(),
            gallery_camid.numpy(),
        )
        self.logger.info('Validation Result:')
        self.logger.info('mAP: {:.2%}'.format(mAP))
        for r in self.cfg.TEST.CMC:
            self.logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))

        self.logger.info('average of mAP and rank1: {:.2%}'.format(
            (mAP + cmc[0]) / 2.0))

        self.logger.info('-' * 20)
Example #4
    def get_result(self, show):
        """

        :param show: whether to display the retrieved query results
        :return: None
        """
        tbar = tqdm.tqdm(self.test_dataloader)
        features_all, names_all = [], []
        with torch.no_grad():
            for i, (images, names) in enumerate(tbar):
                # forward pass through the network
                # features = self.solver.forward((images, torch.zeros((images.size(0)))))[-1]
                features = self.solver.tta(
                    (images, torch.zeros((images.size(0)))))

                features_all.append(features.detach().cpu())
                names_all.extend(names)

        features_all = torch.cat(features_all, dim=0)
        query_features = features_all[:self.num_query]
        gallery_features = features_all[self.num_query:]

        query_names = np.array(names_all[:self.num_query])
        gallery_names = np.array(names_all[self.num_query:])

        if self.dist == 're_rank':
            distmat = re_rank(query_features, gallery_features)
        elif self.dist == 'cos_dist':
            distmat = cos_dist(query_features, gallery_features)
        elif self.dist == 'euclidean_dist':
            distmat = euclidean_dist(query_features, gallery_features)
        else:
            assert "Not implemented :{}".format(self.dist)

        result = {}
        for query_index, query_dist in enumerate(distmat):
            choose_index = np.argsort(query_dist)[:self.num_choose]
            query_name = query_names[query_index]
            gallery_name = gallery_names[choose_index]
            result[query_name] = gallery_name.tolist()
            if query_name in self.demo_names:
                self.show_result(query_name, gallery_name, 5, show)

        with codecs.open('./result.json', 'w', "utf-8") as json_file:
            json.dump(result, json_file, ensure_ascii=False)
Example #5
    def validation(self, valid_loader):
        """ 完成模型的验证过程

        :param valid_loader: 验证集的Dataloader
        :return rank1: rank1得分;类型为float
        :return mAP: 平均检索精度;类型为float
        :return average_score: 平均得分;类型为float
        """
        self.model.eval()
        tbar = tqdm.tqdm(valid_loader)
        features_all, labels_all = [], []
        with torch.no_grad():
            for i, (images, labels, paths) in enumerate(tbar):
                # forward pass through the network
                # features = self.solver.forward((images, labels))[-1]
                features = self.solver.tta((images, labels))
                features_all.append(features.detach().cpu())
                labels_all.append(labels)

        features_all = torch.cat(features_all, dim=0)
        labels_all = torch.cat(labels_all, dim=0)

        query_features = features_all[:self.num_query]
        query_labels = labels_all[:self.num_query]

        gallery_features = features_all[self.num_query:]
        gallery_labels = labels_all[self.num_query:]

        if self.dist == 're_rank':
            distmat = re_rank(query_features, gallery_features)
        elif self.dist == 'cos_dist':
            distmat = cos_dist(query_features, gallery_features)
        elif self.dist == 'euclidean_dist':
            distmat = euclidean_dist(query_features, gallery_features)
        else:
            assert "Not implemented :{}".format(self.dist)

        all_rank_precison, mAP, _ = eval_func(distmat, query_labels.numpy(), gallery_labels.numpy(),
                                              use_cython=self.cython)

        rank1 = all_rank_precison[0]
        average_score = 0.5 * rank1 + 0.5 * mAP
        print('Rank1: {:.2%}, mAP {:.2%}, average score {:.2%}'.format(rank1, mAP, average_score))
        return rank1, mAP, average_score
Example #6
    def get_result(self, show):
        """

        :param show: whether to display the retrieved query results
        :return: None
        """
        tbar = tqdm.tqdm(self.valid_dataloader)
        features_all, labels_all, paths_all = [], [], []
        with torch.no_grad():
            for i, (images, labels, paths) in enumerate(tbar):
                # forward pass through the network
                # features = self.solver.forward(images)[-1]
                features = self.solver.tta(images)

                features_all.append(features.detach().cpu())
                labels_all.extend(labels)
                paths_all.extend(paths)

        features_all = torch.cat(features_all, dim=0)
        query_features = features_all[:self.num_query]
        gallery_features = features_all[self.num_query:]

        query_labels = np.array(labels_all[:self.num_query])
        gallery_labels = np.array(labels_all[self.num_query:])

        query_paths = np.array(paths_all[:self.num_query])
        gallery_paths = np.array(paths_all[self.num_query:])

        if self.dist == 're_rank':
            distmat = re_rank(query_features, gallery_features)
        elif self.dist == 'cos_dist':
            distmat = cos_dist(query_features, gallery_features)
        elif self.dist == 'euclidean_dist':
            distmat = euclidean_dist(query_features, gallery_features)
        else:
            assert "Not implemented :{}".format(self.dist)

        for query_index, query_dist in enumerate(distmat):
            choose_index = np.argsort(query_dist)[:self.num_choose]
            query_path = query_paths[query_index]
            gallery_path = gallery_paths[choose_index]
            query_label = query_labels[query_index]
            gallery_label = gallery_labels[choose_index]
            self.show_result(query_path, gallery_path, query_label, gallery_label, 5, show)
Example #7
    def evaluate(self):
        self.model.eval()
        num_query = self.num_query
        feats, pids, camids = [], [], []
        with torch.no_grad():
            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, pid, camid, _ = batch
                data = data.cuda()
                local_feat_list = self.model(data)
                feat = torch.cat([lf.data.cpu() for lf in local_feat_list],
                                 dim=1)
                feats.append(feat)
                pids.append(pid)
                camids.append(camid)
        feats = torch.cat(feats, dim=0)
        pids = torch.cat(pids, dim=0)
        camids = torch.cat(camids, dim=0)

        query_feat = feats[:num_query]
        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_feat = feats[num_query:]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = euclidean_dist(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(distmat.numpy(),
                                query_pid.numpy(),
                                gallery_pid.numpy(),
                                query_camid.numpy(),
                                gallery_camid.numpy(),
                                use_cython=self.cfg.SOLVER.CYTHON)
        self.logger.info('Validation Result:')
        for r in self.cfg.TEST.CMC:
            self.logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
        self.logger.info('mAP: {:.2%}'.format(mAP))
        self.logger.info('-' * 20)
Example #8
def re_ranking_batch(all_feature, q_num, k1, k2, lambda_value, len_slice=1000):

    # calculate (q+g)*(q+g)
    initial_rank = np.zeros((len(all_feature), k1 + 1)).astype(np.int32)

    original_dist = np.zeros((q_num, len(all_feature)))

    s_time = time.time()

    n_iter = len(all_feature) // len_slice + int(
        len(all_feature) % len_slice > 0)

    with tqdm(total=n_iter) as pbar:
        for i in range(n_iter):
            dis_i_qg = euclidean_dist(
                all_feature[i * len_slice:(i + 1) * len_slice],
                all_feature).data.cpu().numpy()
            initial_i_rank = np.argpartition(
                dis_i_qg,
                range(1, k1 + 1),
            ).astype(np.int32)[:, :k1 + 1]
            initial_rank[i * len_slice:(i + 1) * len_slice] = initial_i_rank
            pbar.update(1)
    # print(initial_rank[0])

    end_time = time.time()
    print("rank time : %s" % (end_time - s_time))

    all_V = []

    s_time = time.time()

    n_iter = len(all_feature) // len_slice + int(
        len(all_feature) % len_slice > 0)

    with tqdm(total=n_iter) as pbar:
        for i in range(n_iter):
            dis_i_qg = euclidean_dist(
                all_feature[i * len_slice:(i + 1) * len_slice],
                all_feature).data.cpu().numpy()
            for ks in range(dis_i_qg.shape[0]):
                r_k = i * len_slice + ks
                dis_i_qg[ks] = np.power(dis_i_qg[ks], 2).astype(np.float32)
                dis_i_qg[ks] = 1. * dis_i_qg[ks] / np.max(dis_i_qg[ks])
                if r_k < q_num:
                    original_dist[r_k] = dis_i_qg[ks]
                V, k_reciprocal_expansion_index, weight = calculate_V(
                    initial_rank, len(all_feature), dis_i_qg[ks], r_k, k1)
                # if r_k == 0:
                #     print(k_reciprocal_expansion_index)
                #     print(weight)
                #     print(dis_i_qg[ks])
                all_V.append(sparse.csr_matrix(V))

            pbar.update(1)

    all_V = sparse.vstack(all_V)
    # print(all_V.getrow(0).toarray())
    end_time = time.time()
    print("calculate V time : %s" % (end_time - s_time))
    # print(all_V.todense()[0])

    all_V_qe = []
    s_time = time.time()
    for i in range(len(all_feature)):
        temp_V = np.zeros((k2, len(all_feature)))
        for l, row_index in enumerate(initial_rank[i, :k2]):
            temp_V[l, :] = all_V.getrow(row_index).toarray()[0]

        V_qe = np.mean(temp_V, axis=0)
        all_V_qe.append(sparse.csr_matrix(V_qe))
    all_V_qe = sparse.vstack(all_V_qe)
    # print(all_V_qe.todense()[0])
    del all_V
    end_time = time.time()
    print("calculate V_qe time : %s" % (end_time - s_time))

    invIndex = []
    for i in range(len(all_feature)):
        invIndex.append(
            np.where(all_V_qe.getcol(i).toarray().transpose()[0] != 0)[0])
    jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)

    for i in range(q_num):
        temp_min = np.zeros(shape=[1, len(all_feature)], dtype=np.float32)

        indNonZero = np.where(all_V_qe.getrow(i).toarray()[0] != 0)[0]

        indImages = []
        indImages = [invIndex[ind] for ind in indNonZero]
        # print(indImages)
        for j in range(len(indNonZero)):
            # print(indNonZero[j])
            c = all_V_qe.getrow(i).getcol(indNonZero[j]).toarray()[0, 0]
            # print(c)
            # print(indImages[j])

            t_min = np.zeros((indImages[j].shape[0]))
            for kk in range(indImages[j].shape[0]):
                temp_d = all_V_qe.getrow(indImages[j][kk]).getcol(
                    indNonZero[j]).toarray()[0, 0]
                t_min[kk] = np.minimum(c, temp_d)
            # print(t_min)

            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + t_min
            # temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
            #                                                                    V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2. - temp_min)
    # print(jaccard_dist[0])
    # print(original_dist[0])
    final_dist = jaccard_dist * (1 -
                                 lambda_value) + original_dist * lambda_value
    del original_dist
    del all_V_qe
    del jaccard_dist
    final_dist = final_dist[:q_num, q_num:]
    return final_dist
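A minimal usage sketch of re_ranking_batch (assuming the function and its calculate_V and euclidean_dist helpers are importable; feature counts and the k1/k2/lambda values are the usual k-reciprocal re-ranking defaults, not values mandated by this code):

import torch

# Toy features: 100 queries followed by 900 gallery items, 512-dim each.
query_feat = torch.randn(100, 512)
gallery_feat = torch.randn(900, 512)
all_feature = torch.cat([query_feat, gallery_feat], dim=0)

final_dist = re_ranking_batch(all_feature, q_num=100, k1=20, k2=6,
                              lambda_value=0.3, len_slice=500)
print(final_dist.shape)  # (100, 900): only the query-to-gallery block is returned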
Example #9
    def evaluate(self):
        self.model.eval()
        num_query = self.num_query
        feats, pids, camids = [], [], []
        histlabels = []
        histpreds = []
        with torch.no_grad():
            for batch in tqdm(self.val_dl, total=len(self.val_dl),
                              leave=False):
                data, pid, camid, _, histlabel = batch
                data = data.cuda()
                # histlabel = histlabel.cuda()

                # ff = torch.FloatTensor(data.size(0), 2048).zero_()
                # for i in range(2):
                #     if i == 1:
                #         data = data.index_select(3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                #     outputs = self.model(data)
                #     f = outputs.data.cpu()
                #     ff = ff + f

                ff, histpred = self.model(data,
                                          output_feature='with_histlabel')
                ff = ff.data.cpu()
                histpred = histpred.data.cpu()
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))

                feats.append(ff)
                pids.append(pid)
                camids.append(camid)
                histlabels.append(histlabel)
                histpreds.append(histpred)
        feats = torch.cat(feats, dim=0)
        pids = torch.cat(pids, dim=0)
        camids = torch.cat(camids, dim=0)
        histpreds = torch.cat(histpreds, dim=0)
        histlabels = torch.cat(histlabels, dim=0)

        hist_acc = (histpreds[:histlabels.size()[0]].max(1)[1] == histlabels
                    ).float().mean().item()

        if self.cfg.TEST.RANDOMPERM <= 0:
            query_feat = feats[:num_query]
            query_pid = pids[:num_query]
            query_camid = camids[:num_query]

            gallery_feat = feats[num_query:]
            gallery_pid = pids[num_query:]
            gallery_camid = camids[num_query:]

            distmat = euclidean_dist(query_feat, gallery_feat)

            cmc, mAP, _ = eval_func(
                distmat.numpy(),
                query_pid.numpy(),
                gallery_pid.numpy(),
                query_camid.numpy(),
                gallery_camid.numpy(),
            )
        else:
            cmc = 0
            mAP = 0
            seed = torch.random.get_rng_state()
            torch.manual_seed(0)
            for i in range(self.cfg.TEST.RANDOMPERM):
                index = torch.randperm(feats.size()[0])
                # print(index[:10])
                query_feat = feats[index][:num_query]
                query_pid = pids[index][:num_query]
                query_camid = camids[index][:num_query]

                gallery_feat = feats[index][num_query:]
                gallery_pid = pids[index][num_query:]
                gallery_camid = camids[index][num_query:]

                distmat = euclidean_dist(query_feat, gallery_feat)

                _cmc, _mAP, _ = eval_func(
                    distmat.numpy(),
                    query_pid.numpy(),
                    gallery_pid.numpy(),
                    query_camid.numpy(),
                    gallery_camid.numpy(),
                )
                cmc += _cmc / self.cfg.TEST.RANDOMPERM
                mAP += _mAP / self.cfg.TEST.RANDOMPERM
            torch.random.set_rng_state(seed)

        self.logger.info('Validation Result:')
        self.logger.info('hist acc:{:.2%}'.format(hist_acc))
        self.logger.info('mAP: {:.2%}'.format(mAP))
        for r in self.cfg.TEST.CMC:
            self.logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))

        self.logger.info('average of mAP and rank1: {:.2%}'.format(
            (mAP + cmc[0]) / 2.0))
        self.logger.info('-' * 20)

        if self.summary_writer:
            self.summary_writer.add_scalar('Valid/hist_acc', hist_acc,
                                           self.train_epoch)
            self.summary_writer.add_scalar('Valid/rank1', cmc[0],
                                           self.train_epoch)
            self.summary_writer.add_scalar('Valid/mAP', mAP, self.train_epoch)
            self.summary_writer.add_scalar('Valid/rank1_mAP',
                                           (mAP + cmc[0]) / 2.0,
                                           self.train_epoch)
Example #10
def test(args):
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger = setup_logger('reid_baseline.eval', cfg.OUTPUT_DIR, 0, train=False)

    logger.info('Running with config:\n{}'.format(cfg))

    _, val_dl, num_query, num_classes = make_dataloader(cfg)

    model = build_model(cfg, num_classes)
    if cfg.TEST.MULTI_GPU:
        model = nn.DataParallel(model)
        model = convert_model(model)
        logger.info('Use multi gpu to inference')
    para_dict = torch.load(cfg.TEST.WEIGHT)
    model.load_state_dict(para_dict)
    model.cuda()
    model.eval()

    feats, pids, camids, paths = [], [], [], []
    with torch.no_grad():
        for batch in tqdm(val_dl, total=len(val_dl), leave=False):
            data, pid, camid, path = batch
            paths.extend(list(path))
            data = data.cuda()
            feat = model(data).detach().cpu()
            feats.append(feat)
            pids.append(pid)
            camids.append(camid)
    feats = torch.cat(feats, dim=0)
    pids = torch.cat(pids, dim=0)
    camids = torch.cat(camids, dim=0)

    query_feat = feats[:num_query]
    query_pid = pids[:num_query]
    query_camid = camids[:num_query]
    query_path = np.array(paths[:num_query])

    gallery_feat = feats[num_query:]
    gallery_pid = pids[num_query:]
    gallery_camid = camids[num_query:]
    gallery_path = np.array(paths[num_query:])

    distmat = euclidean_dist(query_feat, gallery_feat)

    cmc, mAP, all_AP = eval_func(distmat.numpy(),
                                 query_pid.numpy(),
                                 gallery_pid.numpy(),
                                 query_camid.numpy(),
                                 gallery_camid.numpy(),
                                 use_cython=True)

    if cfg.TEST.VIS:
        worst_q = np.argsort(all_AP)[:cfg.TEST.VIS_Q_NUM]
        qid = query_pid[worst_q]
        q_im = query_path[worst_q]

        ind = np.argsort(distmat, axis=1)
        gid = gallery_pid[ind[worst_q]][..., :cfg.TEST.VIS_G_NUM]
        g_im = gallery_path[ind[worst_q]][..., :cfg.TEST.VIS_G_NUM]

        for idx in range(cfg.TEST.VIS_Q_NUM):
            sid = qid[idx] == gid[idx]
            im = rank_list_to_im(range(len(g_im[idx])), sid, q_im[idx],
                                 g_im[idx])

            im.save(
                osp.join(cfg.OUTPUT_DIR,
                         'worst_query_{}.jpg'.format(str(idx).zfill(2))))

    logger.info('Validation Result:')
    for r in cfg.TEST.CMC:
        logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    logger.info('mAP: {:.2%}'.format(mAP))
    logger.info('-' * 20)

    if not cfg.TEST.RERANK:
        return

    distmat = re_rank(query_feat, gallery_feat)
    cmc, mAP, all_AP = eval_func(distmat,
                                 query_pid.numpy(),
                                 gallery_pid.numpy(),
                                 query_camid.numpy(),
                                 gallery_camid.numpy(),
                                 use_cython=True)

    logger.info('ReRanking Result:')
    for r in cfg.TEST.CMC:
        logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    logger.info('mAP: {:.2%}'.format(mAP))
    logger.info('-' * 20)
Example #11
def inference_val(model, transform, batch_size, feature_dim, k1=20, k2=6, p=0.3, use_rerank=False):
    q_img_list = os.listdir(r'E:\data\reid\dataset7\query')
    query_list = list()
    qid_list = list()
    qcid_list = list()
    for q_img in q_img_list:
        query_list.append(os.path.join(r'E:\data\reid\dataset7\query', q_img))
        stem = os.path.splitext(q_img)[0]  # str.strip(".png") strips characters, not a suffix
        qid_list.append(int(stem.split("_")[0]))
        qcid_list.append(int(stem.split("_")[1].lstrip("c")))

    g_img_list = os.listdir(r'E:\data\reid\dataset7\gallery')
    gallery_list = list()
    gid_list = list()
    gcid_list = list()
    for g_img in g_img_list:
        gallery_list.append(os.path.join(r'E:\data\reid\dataset7\gallery', g_img))
        stem = os.path.splitext(g_img)[0]
        gid_list.append(int(stem.split("_")[0]))
        gcid_list.append(int(stem.split("_")[1].lstrip("c")))
    img_list = list()
    for q_img in query_list:
        q_img = read_image(q_img)
        q_img = transform(q_img)
        img_list.append(q_img)
    for g_img in gallery_list:
        g_img = read_image(g_img)
        g_img = transform(g_img)
        img_list.append(g_img)
    query_num = len(query_list)
    img_data = torch.Tensor([t.numpy() for t in img_list])

    model = model.to(device)
    model.eval()
    iter_n = len(img_list) // batch_size
    if len(img_list) % batch_size != 0:
        iter_n += 1
    all_feature = list()
    for i in range(iter_n):
        # print("batch ----%d----" % (i))
        batch_data = img_data[i * batch_size:(i + 1) * batch_size]
        with torch.no_grad():
            # batch_feature = model(batch_data).detach().cpu()

            # horizontal-flip TTA: sum features of the original and flipped image
            ff = torch.FloatTensor(batch_data.size(0), 2048).zero_()
            for flip in range(2):
                if flip == 1:
                    batch_data = batch_data.index_select(3, torch.arange(batch_data.size(3) - 1, -1, -1).long())
                outputs = model(batch_data)
                f = outputs.data.cpu()
                ff = ff + f

            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))

            all_feature.append(ff)
    all_feature = torch.cat(all_feature)
    gallery_feat = all_feature[query_num:]
    query_feat = all_feature[:query_num]
    if use_rerank:
        distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
    else:
        distmat = euclidean_dist(query_feat, gallery_feat)

    cmc, mAP, _ = eval_func(distmat, np.array(qid_list), np.array(gid_list),
                            np.array(qcid_list), np.array(gcid_list))
    print('Validation Result:')
    print(str(k1) + "  -  " + str(k2) + "  -  " + str(p))
    for r in [1, 5, 10]:
        print('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    print('mAP: {:.2%}'.format(mAP))
    with open('re_rank.txt', 'a') as f:
        f.write(str(k1)+"  -  "+str(k2)+"  -  "+str(p) + "\n")
        for r in [1, 5, 10]:
            f.write('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1])+"\n")
        f.write('mAP: {:.2%}'.format(mAP) + "\n")
        f.write('------------------------------------------\n')
        f.write('------------------------------------------\n')
        f.write('\n\n')
Example #12
def inference_samples(model, transform, batch_size, feature_dim, k1=20, k2=6, p=0.3, use_rerank=False):
    query_list = list()
    with open(r'E:\data\reid\初赛A榜测试集\初赛A榜测试集/query_a_list.txt', 'r') as f:
        lines = f.readlines()
        for i, line in enumerate(lines):
            data = line.split(" ")
            image_name = data[0].split("/")[1]
            img_file = os.path.join(r'E:\data\reid\初赛A榜测试集\初赛A榜测试集\query_a', image_name)
            query_list.append(img_file)

    gallery_list = [os.path.join(r'E:\data\reid\初赛A榜测试集\初赛A榜测试集\gallery_a', x) for x in
                    os.listdir(r'E:\data\reid\初赛A榜测试集\初赛A榜测试集\gallery_a')]
    query_num = len(query_list)
    img_list = list()
    for q_img in query_list:
        q_img = read_image(q_img)
        q_img = transform(q_img)
        img_list.append(q_img)
    for g_img in gallery_list:
        g_img = read_image(g_img)
        g_img = transform(g_img)
        img_list.append(g_img)
    img_data = torch.Tensor([t.numpy() for t in img_list])
    model = model.to(device)
    model.eval()
    iter_n = len(img_list) // batch_size
    if len(img_list) % batch_size != 0:
        iter_n += 1
    all_feature = list()
    for i in range(iter_n):
        print("batch ----%d----" % (i))
        batch_data = img_data[i*batch_size:(i+1)*batch_size]
        with torch.no_grad():

            # horizontal-flip TTA: sum features of the original and flipped image
            ff = torch.FloatTensor(batch_data.size(0), feature_dim).zero_()
            for flip in range(2):
                if flip == 1:
                    batch_data = batch_data.index_select(3, torch.arange(batch_data.size(3) - 1, -1, -1).long())
                outputs = model(batch_data)
                f = outputs.data.cpu()
                ff = ff + f

            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))
            all_feature.append(ff)
    all_feature = torch.cat(all_feature)
    gallery_feat = all_feature[query_num:]
    query_feat = all_feature[:query_num]

    if use_rerank:
        print("use re_rank")
        distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
    else:
        distmat = euclidean_dist(query_feat, gallery_feat)
        distmat = distmat.numpy()
    num_q, num_g = distmat.shape
    indices = np.argsort(distmat, axis=1)

    max_200_indices = indices[:, :200]

    res_dict = dict()
    for q_idx in range(num_q):
        print(query_list[q_idx])
        filename = query_list[q_idx][query_list[q_idx].rindex("\\")+1:]
        max_200_files = [gallery_list[i][gallery_list[i].rindex("\\")+1:] for i in max_200_indices[q_idx]]
        res_dict[filename] = max_200_files

    with open(r'submission_A.json', 'w', encoding='utf-8') as f:
        json.dump(res_dict, f)
Example #13
def inference_samples(args,
                      model,
                      transform,
                      batch_size,
                      query_txt,
                      query_dir,
                      gallery_dir,
                      save_dir,
                      k1=20,
                      k2=6,
                      p=0.3,
                      use_rerank=False,
                      use_flip=False,
                      max_rank=200,
                      bn_keys=[]):
    print("==>load data info..")
    if query_txt != "":
        query_list = list()
        with open(query_txt, 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                data = line.split(" ")
                image_name = data[0].split("/")[1]
                img_file = os.path.join(query_dir, image_name)
                query_list.append(img_file)
    else:
        query_list = [
            os.path.join(query_dir, x) for x in os.listdir(query_dir)
        ]
    gallery_list = [
        os.path.join(gallery_dir, x) for x in os.listdir(gallery_dir)
    ]
    query_num = len(query_list)
    if args.save_fname != '':
        print(query_list[:10])
        query_list = sorted(query_list)
        print(query_list[:10])
        gallery_list = sorted(gallery_list)
    print("==>build dataloader..")
    image_set = ImageDataset(query_list + gallery_list, transform)
    dataloader = DataLoader(image_set,
                            sampler=SequentialSampler(image_set),
                            batch_size=batch_size,
                            num_workers=4)
    bn_dataloader = DataLoader(image_set,
                               sampler=RandomSampler(image_set),
                               batch_size=batch_size,
                               num_workers=4,
                               drop_last=True)

    print("==>model inference..")

    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           bn_dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, bn_dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats = []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data = batch
            data = data.cuda()

            if use_flip:
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for i in range(2):
                    # flip
                    if i == 1:
                        data = data.index_select(
                            3,
                            torch.arange(data.size(3) - 1, -1,
                                         -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    # cat
                    if i == 0:
                        ff[:, :2048] = f
                    if i == 1:
                        ff[:, 2048:] = f
            else:
                ff = model(data).data.cpu()
            feats.append(ff)

    all_feature = torch.cat(feats, dim=0)

    # DBA
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()

        # [todo] heap sort
        distmat = euclidean_dist(all_feature, all_feature)
        # initial_rank = distmat.numpy().argsort(axis=1)
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()

        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] *
                                     weights,
                                     axis=0)
                pbar.update(1)
        # import pdb;pdb.set_trace()
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)
    # aQE: weight query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            k2, alpha))
        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)
        # all_feature = all_feature.numpy()

        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(aqe_func(all_feature[i],all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)
        # all_feature = np.stack(all_features,axis=0)
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        # fast by gpu
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)
    print('feature shape:', all_feature.size())
    if args.pseudo:
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))

        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)

        if args.pseudo_visual:
            all_distmat, kdist = get_sparse_distmat(
                all_feature,
                eps=args.pseudo_eps + 0.1,
                len_slice=2000,
                use_gpu=True,
                dist_k=args.pseudo_minpoints)
            plt.plot(list(range(len(kdist))), np.sort(kdist), linewidth=0.5)
            plt.savefig('test_kdist.png')
            plt.savefig(save_dir + 'test_kdist.png')
        else:
            all_distmat = get_sparse_distmat(all_feature,
                                             eps=args.pseudo_eps + 0.1,
                                             len_slice=2000,
                                             use_gpu=True)

        # print(all_distmat.todense()[0])

        pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                            args.pseudo_minpoints,
                                            args.pseudo_maxpoints,
                                            args.pseudo_algorithm)
        print("pseudo cost: {}s".format(time.time() - st))
        print("pseudo id cnt:", len(pseudolabels))
        print("pseudo img cnt:",
              len([x for k, v in pseudolabels.items() for x in v]))

        # # save
        all_list = query_list + gallery_list
        save_path = args.pseudo_savepath

        pid = args.pseudo_startid
        camid = 0
        for k, v in pseudolabels.items():
            os.makedirs(os.path.join(save_path, str(pid)), exist_ok=True)
            for _index in pseudolabels[k]:
                filename = all_list[_index].split("/")[-1]
                new_filename = str(pid) + "_c" + str(camid) + ".png"
                shutil.copy(all_list[_index],
                            os.path.join(save_path, str(pid), new_filename))
                camid += 1
            pid += 1

    gallery_feat = all_feature[query_num:]
    query_feat = all_feature[:query_num]

    if use_rerank:
        print("==>use re_rank")
        st = time.time()
        distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
        print("re_rank cost:", time.time() - st)
    else:
        st = time.time()
        weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
        print('==> using mgn_euclidean_dist')
        print('==> using weights:', weights)
        # distmat = euclidean_dist(query_feat, gallery_feat)
        distmat = None
        if use_flip:
            for i in range(2):
                dist = mgn_euclidean_dist(
                    query_feat[:, i * 2048:(i + 1) * 2048],
                    gallery_feat[:, i * 2048:(i + 1) * 2048],
                    norm=True,
                    weights=weights)
                if distmat is None:
                    distmat = dist / 2.0
                else:
                    distmat += dist / 2.0
        else:
            distmat = mgn_euclidean_dist(query_feat,
                                         gallery_feat,
                                         norm=True,
                                         weights=weights)

        print("euclidean_dist cost:", time.time() - st)

        distmat = distmat.numpy()

    num_q, num_g = distmat.shape
    print("==>saving..")
    if args.post:
        qfnames = [fname.split('/')[-1] for fname in query_list]
        gfnames = [fname.split('/')[-1] for fname in gallery_list]
        st = time.time()
        print("post json using top_per:", args.post_top_per)
        res_dict = get_post_json(distmat, qfnames, gfnames, args.post_top_per)
        print("post cost:", time.time() - st)
    else:
        # [todo] fast test
        print("==>sorting..")
        st = time.time()
        indices = np.argsort(distmat, axis=1)
        print("argsort cost:", time.time() - st)
        # print(indices[:2, :max_rank])
        # st = time.time()
        # indices = np.argpartition( distmat, range(1,max_rank+1))
        # print("argpartition cost:",time.time()-st)
        # print(indices[:2, :max_rank])

        max_200_indices = indices[:, :max_rank]
        res_dict = dict()
        for q_idx in range(num_q):
            filename = query_list[q_idx].split('/')[-1]
            max_200_files = [
                gallery_list[i].split('/')[-1] for i in max_200_indices[q_idx]
            ]
            res_dict[filename] = max_200_files
    if args.dba:
        save_fname = 'mgnsub_dba.json'
    elif args.aqe:
        save_fname = 'mgnsub_aqe.json'
    else:
        save_fname = 'mgnsub.json'
    if use_rerank:
        save_fname = 'rerank_' + save_fname
    if args.adabn:
        if args.adabn_all:
            save_fname = 'adabnall_' + save_fname
        else:
            save_fname = 'adabn_' + save_fname
    if use_flip:
        save_fname = 'flip_' + save_fname
    if args.post:
        save_fname = 'post_' + save_fname
    save_fname = args.save_fname + save_fname
    print('savefname:', save_fname)
    with open(save_dir + save_fname, 'w', encoding='utf-8') as f:
        json.dump(res_dict, f)
    with open(save_dir + save_fname.replace('.json', '.pkl'), 'wb') as fid:
        pickle.dump(distmat, fid, -1)
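For reference, a small sketch of the DBA weighting used in the block above (numpy only; alpha must be negative, which is what the assert in that block enforces):

import numpy as np

# Log-spaced decay from 1.0 down to 10**alpha, applied row-wise to the
# features of the k2 nearest neighbours before averaging.
alpha, k2 = -3.0, 6
weights = np.logspace(0, alpha, k2).reshape((-1, 1))
print(weights.ravel())  # [1.0, 0.2512, 0.0631, 0.0158, 0.0040, 0.001]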
Example #14
def inference_val(args,
                  model,
                  dataloader,
                  num_query,
                  save_dir,
                  k1=20,
                  k2=6,
                  p=0.3,
                  use_rerank=False,
                  use_flip=False,
                  n_randperm=0,
                  bn_keys=[]):
    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats, pids, camids = [], [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data, pid, camid, _ = batch
            data = data.cuda()

            if use_flip:
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for i in range(2):
                    # flip
                    if i == 1:
                        data = data.index_select(
                            3,
                            torch.arange(data.size(3) - 1, -1,
                                         -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    # cat
                    if i == 0:
                        ff[:, :2048] = f
                    if i == 1:
                        ff[:, 2048:] = f
                # ff = F.normalize(ff, p=2, dim=1)
            else:
                ff = model(data).data.cpu()
                # ff = F.normalize(ff, p=2, dim=1)

            feats.append(ff)
            pids.append(pid)
            camids.append(camid)
    all_feature = torch.cat(feats, dim=0)
    # all_feature = all_feature[:,:1024+512]
    pids = torch.cat(pids, dim=0)
    camids = torch.cat(camids, dim=0)

    # DBA
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()

        # [todo] heap sort
        distmat = euclidean_dist(all_feature, all_feature)
        # initial_rank = distmat.numpy().argsort(axis=1)
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()

        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] *
                                     weights,
                                     axis=0)
                pbar.update(1)
        # import pdb;pdb.set_trace()
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)
    # aQE: weight query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            k2, alpha))
        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)
        # all_feature = all_feature.numpy()

        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(aqe_func(all_feature[i],all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)
        # all_feature = np.stack(all_features,axis=0)
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)

        # norm_feature = F.normalize(all_feature, p=2, dim=1)
        # norm_feature = norm_feature.numpy()
        # all_feature = all_feature.numpy()

        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(aqe_nonorm_func(norm_feature[i],all_norm_feature_T=norm_feature.T ,all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)

        # part 2hour for val
        # part_norm_feat = []
        # for i in range(8):
        #     norm_feature = F.normalize(all_feature[:,i*256:(i+1)*256], p=2, dim=1)
        #     part_norm_feat.append(norm_feature)
        # norm_feature = torch.cat(part_norm_feat,dim=1)

        # norm_feature = norm_feature.numpy()
        # all_feature = all_feature.numpy()

        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(mgn_aqe_func(norm_feature[i],all_norm_feature_T=norm_feature.T ,all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)

        # all_feature = np.stack(all_features,axis=0)
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)
        # import pdb;pdb.set_trace()

    print('feature shape:', all_feature.size())

    #
    # for k1 in range(5,10,2):
    #     for k2 in range(2,5,1):
    #         for l in range(5,8):
    #             p = l*0.1

    if n_randperm <= 0:
        k2 = args.k2
        gallery_feat = all_feature[num_query:]
        query_feat = all_feature[:num_query]

        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = None
        if use_rerank:
            print('==> using rerank')
            distmat = re_rank(query_feat, gallery_feat, args.k1, args.k2, p)
        else:
            # if args.aqe:
            #     print('==> using euclidean_dist')
            #     distmat = euclidean_dist(query_feat, gallery_feat)
            # else:
            weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
            print('==> using mgn_euclidean_dist')
            print('==> using weights:', weights)
            if use_flip:
                for i in range(2):
                    dist = mgn_euclidean_dist(
                        query_feat[:, i * 2048:(i + 1) * 2048],
                        gallery_feat[:, i * 2048:(i + 1) * 2048],
                        norm=True,
                        weights=weights)
                    if distmat is None:
                        distmat = dist / 2.0
                    else:
                        distmat += dist / 2.0
            else:
                distmat = mgn_euclidean_dist(query_feat,
                                             gallery_feat,
                                             norm=True,
                                             weights=weights)

        cmc, mAP, _ = eval_func(distmat, query_pid.numpy(),
                                gallery_pid.numpy(), query_camid.numpy(),
                                gallery_camid.numpy())
    else:
        k2 = args.k2
        torch.manual_seed(0)
        cmc = 0
        mAP = 0
        for i in range(n_randperm):
            index = torch.randperm(all_feature.size()[0])

            query_feat = all_feature[index][:num_query]
            gallery_feat = all_feature[index][num_query:]

            query_pid = pids[index][:num_query]
            query_camid = camids[index][:num_query]

            gallery_pid = pids[index][num_query:]
            gallery_camid = camids[index][num_query:]

            if use_rerank:
                print('==> using rerank')
                st = time.time()
                distmat = re_rank(query_feat, gallery_feat, args.k1, args.k2,
                                  p)
                print("re_rank cost:", time.time() - st)

            else:
                print('==> using euclidean_dist')
                st = time.time()
                # distmat = euclidean_dist(query_feat, gallery_feat)
                # weights = [1,1,1,1,1,1,1,1]
                weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
                # weights = [2,1,1,1/2,1/2,1/3,1/3,1/3]

                print('==> using mgn_euclidean_dist')
                print('==> using weights:', weights)
                # distmat = euclidean_dist(query_feat, gallery_feat)
                distmat = None
                if use_flip:
                    for i in range(2):
                        dist = mgn_euclidean_dist(
                            query_feat[:, i * 2048:(i + 1) * 2048],
                            gallery_feat[:, i * 2048:(i + 1) * 2048],
                            norm=True,
                            weights=weights)
                        if distmat is None:
                            distmat = dist / 2.0
                        else:
                            distmat += dist / 2.0
                else:
                    distmat = mgn_euclidean_dist(query_feat,
                                                 gallery_feat,
                                                 norm=True,
                                                 weights=weights)
                print("euclidean_dist cost:", time.time() - st)

            _cmc, _mAP, _ = eval_func(distmat, query_pid.numpy(),
                                      gallery_pid.numpy(), query_camid.numpy(),
                                      gallery_camid.numpy())
            cmc += _cmc / n_randperm
            mAP += _mAP / n_randperm

    print('Validation Result:')
    if use_rerank:
        print(str(k1) + "  -  " + str(k2) + "  -  " + str(p))
    print('mAP: {:.2%}'.format(mAP))
    for r in [1, 5, 10]:
        print('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    print('average of mAP and rank1: {:.2%}'.format((mAP + cmc[0]) / 2.0))

    with open(save_dir + 'eval.txt', 'a') as f:
        if use_rerank:
            f.write('==> using rerank\n')
            f.write(str(k1) + "  -  " + str(k2) + "  -  " + str(p) + "\n")
        else:
            f.write('==> using euclidean_dist\n')

        f.write('mAP: {:.2%}'.format(mAP) + "\n")
        for r in [1, 5, 10]:
            f.write('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]) + "\n")
        f.write('average of mAP and rank1: {:.2%}\n'.format(
            (mAP + cmc[0]) / 2.0))

        f.write('------------------------------------------\n')
        f.write('------------------------------------------\n')
        f.write('\n\n')
Example #15
def re_ranking_batch_gpu(all_feature,
                         q_num,
                         k1,
                         k2,
                         lambda_value,
                         len_slice=1000):

    # calculate (q+g)*(q+g)
    initial_rank = np.zeros((len(all_feature), k1 + 1)).astype(np.int32)

    original_dist = np.zeros((q_num, len(all_feature)))
    gpu_features = all_feature.cuda()
    s_time = time.time()

    n_iter = len(all_feature) // len_slice + int(
        len(all_feature) % len_slice > 0)

    with tqdm(total=n_iter) as pbar:
        for i in range(n_iter):
            dis_i_qg = euclidean_dist(
                gpu_features[i * len_slice:(i + 1) * len_slice],
                gpu_features).data.cpu().numpy()
            initial_i_rank = np.argpartition(
                dis_i_qg,
                range(1, k1 + 1),
            ).astype(np.int32)[:, :k1 + 1]
            initial_rank[i * len_slice:(i + 1) * len_slice] = initial_i_rank
            pbar.update(1)
    # print(initial_rank[0])

    end_time = time.time()
    print("rank time : %s" % (end_time - s_time))

    all_V = []

    s_time = time.time()

    n_iter = len(all_feature) // len_slice + int(
        len(all_feature) % len_slice > 0)

    with tqdm(total=n_iter) as pbar:
        for i in range(n_iter):
            dis_i_qg = euclidean_dist(
                gpu_features[i * len_slice:(i + 1) * len_slice],
                gpu_features).data.cpu().numpy()
            for ks in range(dis_i_qg.shape[0]):
                r_k = i * len_slice + ks
                dis_i_qg[ks] = np.power(dis_i_qg[ks], 2).astype(np.float32)
                dis_i_qg[ks] = 1. * dis_i_qg[ks] / np.max(dis_i_qg[ks])
                if r_k < q_num:
                    original_dist[r_k] = dis_i_qg[ks]
                V, k_reciprocal_expansion_index, weight = calculate_V(
                    initial_rank, len(all_feature), dis_i_qg[ks], r_k, k1)
                # if r_k == 0:
                #     print(k_reciprocal_expansion_index)
                #     print(weight)
                #     print(dis_i_qg[ks])
                all_V.append(sparse.csr_matrix(V))

            pbar.update(1)

    all_V = sparse.vstack(all_V)
    # print(all_V.getrow(0).toarray())
    end_time = time.time()
    print("calculate V time : %s" % (end_time - s_time))
    # print(all_V.todense()[0])

    all_V_qe = []
    s_time = time.time()
    for i in range(len(all_feature)):
        temp_V = np.zeros((k2, len(all_feature)))
        for l, row_index in enumerate(initial_rank[i, :k2]):
            temp_V[l, :] = all_V.getrow(row_index).toarray()[0]

        V_qe = np.mean(temp_V, axis=0)
        all_V_qe.append(sparse.csr_matrix(V_qe))
    all_V_qe = sparse.vstack(all_V_qe)
    # print(all_V_qe.todense()[0])
    del all_V
    end_time = time.time()
    print("calculate V_qe time : %s" % (end_time - s_time))

    invIndex = []
    for i in range(len(all_feature)):
        invIndex.append(
            np.where(all_V_qe.getcol(i).toarray().transpose()[0] != 0)[0])
    jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)

    with tqdm(total=q_num) as pbar:
        for i in range(q_num):
            temp_min = np.zeros(shape=[1, len(all_feature)], dtype=np.float32)

            indNonZero = np.where(all_V_qe.getrow(i).toarray()[0] != 0)[0]

            indImages = []
            indImages = [invIndex[ind] for ind in indNonZero]
            # print(indImages)
            for j in range(len(indNonZero)):
                # print(indNonZero[j])
                c = all_V_qe.getrow(i).getcol(indNonZero[j]).toarray()[0, 0]
                # print(c)
                # print(indImages[j])

                t_min = np.zeros((indImages[j].shape[0]))
                for kk in range(indImages[j].shape[0]):
                    temp_d = all_V_qe.getrow(indImages[j][kk]).getcol(
                        indNonZero[j]).toarray()[0, 0]
                    t_min[kk] = np.minimum(c, temp_d)
                # print(t_min)

                temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + t_min
                # temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                #                                                                    V[indImages[j], indNonZero[j]])
            jaccard_dist[i] = 1 - temp_min / (2. - temp_min)
            pbar.update(1)
    # print(jaccard_dist[0])
    # print(original_dist[0])
    final_dist = jaccard_dist * (1 -
                                 lambda_value) + original_dist * lambda_value
    del original_dist
    del all_V_qe
    del jaccard_dist
    final_dist = final_dist[:q_num, q_num:]
    return final_dist


# def re_ranking_batch(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
#
#     # The following naming, e.g. gallery_num, is different from outer scope.
#     # Don't care about it.
#
#     original_dist = np.concatenate(
#       [np.concatenate([q_q_dist, q_g_dist], axis=1),
#        np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
#       axis=0)
#     original_dist = np.power(original_dist, 2).astype(np.float32)
#     original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0))
#     V = np.zeros_like(original_dist).astype(np.float32)
#     # initial_rank = np.argsort(original_dist).astype(np.int32)
#     # # fast sort top K1+1
#     initial_rank = np.argpartition( original_dist, range(1,k1+1)).astype(np.int32)
#
#     query_num = q_g_dist.shape[0]
#     gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
#     all_num = gallery_num
#
#     for i in range(all_num):
#         # k-reciprocal neighbors
#         forward_k_neigh_index = initial_rank[i,:k1+1]
#         backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
#         fi = np.where(backward_k_neigh_index==i)[0]
#         k_reciprocal_index = forward_k_neigh_index[fi]
#         k_reciprocal_expansion_index = k_reciprocal_index
#         for j in range(len(k_reciprocal_index)):
#             candidate = k_reciprocal_index[j]
#             candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1]
#             candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1]
#             fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
#             candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
#             if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):
#                 k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
#
#         k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
#         weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
#         V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)
#     original_dist = original_dist[:query_num,]
#     if k2 != 1:
#         V_qe = np.zeros_like(V,dtype=np.float32)
#         for i in range(all_num):
#             V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
#
#         # # DBA
#         # alpha = -3.0
#         # weights = np.logspace(0,alpha,k2).reshape((-1,1))
#         # for i in range(all_num):
#         #     V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:]*weights,axis=0)
#
#         V = V_qe
#         del V_qe
#     del initial_rank
#     invIndex = []
#     for i in range(gallery_num):
#         invIndex.append(np.where(V[:,i] != 0)[0])
#
#     jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)
#
#
#     for i in range(query_num):
#         temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32)
#         indNonZero = np.where(V[i,:] != 0)[0]
#         indImages = []
#         indImages = [invIndex[ind] for ind in indNonZero]
#         for j in range(len(indNonZero)):
#             temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
#         jaccard_dist[i] = 1-temp_min/(2.-temp_min)
#
#     final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
#     del original_dist
#     del V
#     del jaccard_dist
#     final_dist = final_dist[:query_num,query_num:]
#     return final_dist
Example #16
    if len(img_list) == 1:
        continue
    for i in range(iter_n):
        # print("batch ----%d----" % (i))

        batch_data = img_data[i * batch_size:(i + 1) * batch_size]
        with torch.no_grad():
            # batch_feature = model(batch_data).detach().cpu()

            # sum the features of the original and the horizontally flipped batch
            # (use a separate loop variable so the outer batch index i is not shadowed)
            ff = torch.FloatTensor(batch_data.size(0), 2048).zero_()
            for flip in range(2):
                if flip == 1:
                    batch_data = batch_data.index_select(3, torch.arange(batch_data.size(3) - 1, -1, -1).long())
                outputs = model(batch_data)
                f = outputs.data.cpu()
                ff = ff + f

            # L2-normalize the summed feature
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))

            all_feature.append(ff)
    all_feature = torch.cat(all_feature)
    gallery_feat = all_feature
    query_feat = all_feature
    distmat = euclidean_dist(query_feat, gallery_feat).numpy()
    # report image pairs within this group whose feature distance exceeds 0.5
    for m in range(distmat.shape[0]):
        for n in range(distmat.shape[1]):
            v = distmat[m, n]
            if v > 0.5:
                print(imgs[m], imgs[n])
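
# A self-contained sketch (illustration only, not the author's helper) of the
# horizontal-flip test-time-augmentation pattern used in the snippet above:
# features of the original and the flipped image are summed, then L2-normalized.
import torch


@torch.no_grad()
def extract_flip_tta(model, batch, feat_dim=2048):
    # batch: (B, C, H, W) tensor on the same device as the model
    ff = torch.zeros(batch.size(0), feat_dim)
    for flip in range(2):
        if flip == 1:
            # reverse the width dimension (dim 3) to flip horizontally
            batch = batch.index_select(
                3, torch.arange(batch.size(3) - 1, -1, -1, device=batch.device))
        ff = ff + model(batch).data.cpu()
    fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
    return ff.div(fnorm.expand_as(ff))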
Example #17
def inference_samples(args,
                      model,
                      transform,
                      batch_size,
                      query_txt,
                      query_dir,
                      gallery_dir,
                      save_dir,
                      k1=20,
                      k2=6,
                      p=0.3,
                      use_rerank=False,
                      use_flip=False,
                      max_rank=200,
                      bn_keys=[]):
    print("==>load data info..")
    if query_txt != "":
        query_list = list()
        with open(query_txt, 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                data = line.split(" ")
                image_name = data[0].split("/")[1]
                img_file = os.path.join(query_dir, image_name)
                query_list.append(img_file)
    else:
        query_list = [
            os.path.join(query_dir, x) for x in os.listdir(query_dir)
        ]
    gallery_list = [
        os.path.join(gallery_dir, x) for x in os.listdir(gallery_dir)
    ]
    query_num = len(query_list)
    if args.save_fname != '':
        print(query_list[:10])
        query_list = sorted(query_list)
        print(query_list[:10])
        gallery_list = sorted(gallery_list)
    print("==>build dataloader..")
    image_set = ImageDataset(query_list + gallery_list, transform)
    dataloader = DataLoader(image_set,
                            sampler=SequentialSampler(image_set),
                            batch_size=batch_size,
                            num_workers=6)
    bn_dataloader = DataLoader(image_set,
                               sampler=RandomSampler(image_set),
                               batch_size=batch_size,
                               num_workers=6,
                               drop_last=True)

    print("==>model inference..")

    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           bn_dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, bn_dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats = []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data = batch
            data = data.cuda()

            if use_flip:
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for i in range(2):
                    # flip
                    if i == 1:
                        data = data.index_select(
                            3,
                            torch.arange(data.size(3) - 1, -1,
                                         -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    # cat
                    if i == 0:
                        ff[:, :2048] = F.normalize(f, p=2, dim=1)
                    if i == 1:
                        ff[:, 2048:] = F.normalize(f, p=2, dim=1)
                ff = F.normalize(ff, p=2, dim=1)
            else:
                ff = model(data).data.cpu()
                ff = F.normalize(ff, p=2, dim=1)
            feats.append(ff)

    all_feature = torch.cat(feats, dim=0)

    # DBA: database-augmentation style feature expansion (a standalone sketch follows this example)
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()

        # [todo] heap sort
        distmat = euclidean_dist(all_feature, all_feature)
        # initial_rank = distmat.numpy().argsort(axis=1)
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()

        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] *
                                     weights,
                                     axis=0)
                pbar.update(1)
        # import pdb;pdb.set_trace()
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)
    # aQE: weighted query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            k2, alpha))
        st = time.time()

        # fast by gpu
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)
    print('feature shape:', all_feature.size())
    if args.pseudo:
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))

        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)
        if args.pseudo_hist:
            print("==> predict histlabel...")

            img_filenames = query_list + gallery_list
            img_idx = list(range(len(img_filenames)))
            imgs = {'filename': img_filenames, 'identity': img_idx}
            df_img = pd.DataFrame(imgs)
            hist_labels = mmcv.track_parallel_progress(img_hist_predictor,
                                                       df_img['filename'], 6)
            print("hist label describe..")
            unique_hist_labels = sorted(list(set(hist_labels)))

            hl_idx = []

            hl_query_infos = []
            hl_gallery_infos = []

            for label_idx in range(len(unique_hist_labels)):
                hl_query_infos.append([])
                hl_gallery_infos.append([])
                hl_idx.append([])
            with tqdm(total=len(img_filenames)) as pbar:
                for idx, info in enumerate(img_filenames):
                    for label_idx in range(len(unique_hist_labels)):
                        if hist_labels[idx] == unique_hist_labels[label_idx]:
                            if idx < len(query_list):
                                hl_query_infos[label_idx].append(info)
                            else:
                                hl_gallery_infos[label_idx].append(info)
                            hl_idx[label_idx].append(idx)
                    pbar.update(1)
            for label_idx in range(len(unique_hist_labels)):
                print('hist_label:', unique_hist_labels[label_idx],
                      ' query number:', len(hl_query_infos[label_idx]))
                print('hist_label:', unique_hist_labels[label_idx],
                      ' gallery number:', len(hl_gallery_infos[label_idx]))
                print(
                    'hist_label:', unique_hist_labels[label_idx],
                    ' q+g number:',
                    len(hl_query_infos[label_idx]) +
                    len(hl_gallery_infos[label_idx]))
                print('hist_label:', unique_hist_labels[label_idx],
                      ' idx q+g number:', len(hl_idx[label_idx]))
            # pseudo
            pid = args.pseudo_startid
            camid = 0
            all_list = query_list + gallery_list
            save_path = args.pseudo_savepath
            pseudo_eps = args.pseudo_eps
            pseudo_minpoints = args.pseudo_minpoints
            for label_idx in range(len(unique_hist_labels)):
                # if label_idx == 0:
                #     pseudo_eps = 0.6
                # else:
                #     pseudo_eps = 0.75
                if label_idx == 0:
                    pseudo_eps = 0.65
                else:
                    pseudo_eps = 0.80

                feature = all_feature[hl_idx[label_idx]]
                img_list = [all_list[idx] for idx in hl_idx[label_idx]]

                print("==> get sparse distmat!")
                if args.pseudo_visual:
                    all_distmat, kdist = get_sparse_distmat(
                        feature,
                        eps=pseudo_eps + 0.05,
                        len_slice=2000,
                        use_gpu=True,
                        dist_k=pseudo_minpoints)
                    plt.plot(list(range(len(kdist))),
                             np.sort(kdist),
                             linewidth=0.5)
                    plt.savefig('test_kdist_hl{}_eps{}_{}.png'.format(
                        label_idx, pseudo_eps, pseudo_minpoints))
                    plt.savefig(save_dir +
                                'test_kdist_hl{}_eps{}_{}.png'.format(
                                    label_idx, pseudo_eps, pseudo_minpoints))
                else:
                    all_distmat = get_sparse_distmat(feature,
                                                     eps=pseudo_eps + 0.05,
                                                     len_slice=2000,
                                                     use_gpu=True)

                print("==> predict pseudo label!")
                pseudolabels = predict_pseudo_label(all_distmat, pseudo_eps,
                                                    pseudo_minpoints,
                                                    args.pseudo_maxpoints,
                                                    args.pseudo_algorithm)
                print(
                    "==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
                        pseudo_eps, pseudo_minpoints, args.pseudo_maxpoints))
                print("pseudo cost: {}s".format(time.time() - st))
                print("pseudo id cnt:", len(pseudolabels))
                print("pseudo img cnt:",
                      len([x for k, v in pseudolabels.items() for x in v]))

                # sampling stride over pseudo ids (currently 1, i.e. keep every id, for both hist labels)
                if label_idx == 0:
                    sf = 1
                else:
                    sf = 1
                sample_id_cnt = 0
                sample_file_cnt = 0
                nignore_query = 0
                for i, (k, v) in enumerate(pseudolabels.items()):
                    if i % sf != 0:
                        continue
                    # query_cnt = 0
                    # for _index in pseudolabels[k]:
                    #     if _index<len(query_list):
                    #         query_cnt += 1
                    # if query_cnt>=2:
                    #     nignore_query += 1
                    #     continue
                    os.makedirs(os.path.join(save_path, str(pid)),
                                exist_ok=True)

                    for _index in pseudolabels[k]:
                        filename = img_list[_index].split("/")[-1]
                        new_filename = str(pid) + "_c" + str(camid) + ".png"
                        shutil.copy(
                            img_list[_index],
                            os.path.join(save_path, str(pid), new_filename))
                        camid += 1
                        sample_file_cnt += 1
                    sample_id_cnt += 1
                    pid += 1
                print("pseudo ignore id cnt:", nignore_query)
                print("sample id cnt:", sample_id_cnt)
                print("sample file cnt:", sample_file_cnt)
        else:
            if args.pseudo_visual:
                all_distmat, kdist = get_sparse_distmat(
                    all_feature,
                    eps=args.pseudo_eps + 0.05,
                    len_slice=2000,
                    use_gpu=True,
                    dist_k=args.pseudo_minpoints)
                plt.plot(list(range(len(kdist))),
                         np.sort(kdist),
                         linewidth=0.5)
                plt.savefig('test_kdist.png')
                plt.savefig(save_dir + 'test_kdist.png')
            else:
                all_distmat = get_sparse_distmat(all_feature,
                                                 eps=args.pseudo_eps + 0.05,
                                                 len_slice=2000,
                                                 use_gpu=True)

            # print(all_distmat.todense()[0])

            pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                                args.pseudo_minpoints,
                                                args.pseudo_maxpoints,
                                                args.pseudo_algorithm)
            print("pseudo cost: {}s".format(time.time() - st))
            print("pseudo id cnt:", len(pseudolabels))
            print("pseudo img cnt:",
                  len([x for k, v in pseudolabels.items() for x in v]))

            # # save
            all_list = query_list + gallery_list
            save_path = args.pseudo_savepath

            pid = args.pseudo_startid
            camid = 0
            nignore_query = 0
            for k, v in pseudolabels.items():
                # [filter] skip pseudo ids that contain too many query images
                query_cnt = 0
                for _index in pseudolabels[k]:
                    if _index < len(query_list):
                        query_cnt += 1
                if query_cnt >= 4:
                    nignore_query += 1
                    continue
                os.makedirs(os.path.join(save_path, str(pid)), exist_ok=True)
                for _index in pseudolabels[k]:
                    filename = all_list[_index].split("/")[-1]
                    new_filename = str(pid) + "_c" + str(camid) + ".png"
                    shutil.copy(
                        all_list[_index],
                        os.path.join(save_path, str(pid), new_filename))
                    camid += 1
                pid += 1
            print("pseudo ignore id cnt:", nignore_query)
    else:
        gallery_feat = all_feature[query_num:]
        query_feat = all_feature[:query_num]

        if use_rerank:
            print("==>use re_rank")
            st = time.time()
            k2 = args.k2
            # distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
            num_query = len(query_feat)
            print("using k1:{} k2:{} lambda:{}".format(args.k1, args.k2, p))
            distmat = re_ranking_batch_gpu(
                torch.cat([query_feat, gallery_feat], dim=0), num_query,
                args.k1, args.k2, p)

            print("re_rank cost:", time.time() - st)

        else:
            print("==>use euclidean_dist")
            st = time.time()
            distmat = euclidean_dist(query_feat, gallery_feat)
            print("euclidean_dist cost:", time.time() - st)

            distmat = distmat.numpy()

        num_q, num_g = distmat.shape
        print("==>saving..")
        if args.post:
            qfnames = [fname.split('/')[-1] for fname in query_list]
            gfnames = [fname.split('/')[-1] for fname in gallery_list]
            st = time.time()
            print("post json using top_per:", args.post_top_per)
            res_dict = get_post_json(distmat, qfnames, gfnames,
                                     args.post_top_per)
            print("post cost:", time.time() - st)
        else:
            # [todo] fast test
            print("==>sorting..")
            st = time.time()
            indices = np.argsort(distmat, axis=1)
            print("argsort cost:", time.time() - st)
            # print(indices[:2, :max_rank])
            # st = time.time()
            # indices = np.argpartition( distmat, range(1,max_rank+1))
            # print("argpartition cost:",time.time()-st)
            # print(indices[:2, :max_rank])

            max_200_indices = indices[:, :max_rank]
            res_dict = dict()
            for q_idx in range(num_q):
                filename = query_list[q_idx].split('/')[-1]
                max_200_files = [
                    gallery_list[i].split('/')[-1]
                    for i in max_200_indices[q_idx]
                ]
                res_dict[filename] = max_200_files
        if args.dba:
            save_fname = 'sub_dba.json'
        elif args.aqe:
            save_fname = 'sub_aqe.json'
        else:
            save_fname = 'sub.json'
        if use_rerank:
            save_fname = 'rerank_' + save_fname
        if args.adabn:
            if args.adabn_all:
                save_fname = 'adabnall_' + save_fname
            else:
                save_fname = 'adabn_' + save_fname
        if use_flip:
            save_fname = 'flip_' + save_fname
        if args.post:
            save_fname = 'post_' + save_fname
        save_fname = args.save_fname + save_fname
        print('savefname:', save_fname)
        with open(save_dir + save_fname, 'w', encoding='utf-8') as f:
            json.dump(res_dict, f)
        with open(save_dir + save_fname.replace('.json', '.pkl'), 'wb') as fid:
            pickle.dump(distmat, fid, -1)
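
# A minimal, self-contained sketch (an illustration, not the exact code above) of the
# DBA step used in inference_samples: every feature is replaced by a log-spaced weighted
# mean of its k2 nearest neighbours (alpha < 0, so farther neighbours get smaller
# weights), then L2-normalized again.
import numpy as np
import torch
import torch.nn.functional as F


def dba_expand(features, k2=10, alpha=-3.0):
    assert alpha < 0
    dist = torch.cdist(features, features)                 # (N, N) pairwise distances
    knn = dist.topk(k2, dim=1, largest=False).indices      # k2 nearest rows (incl. self)
    weights = torch.from_numpy(
        np.logspace(0, alpha, k2).astype(np.float32)).view(1, k2, 1)
    expanded = (features[knn] * weights).mean(dim=1)       # weighted mean over neighbours
    return F.normalize(expanded, p=2, dim=1)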
Example #18
def inference_val(args,
                  model,
                  dataloader,
                  num_query,
                  save_dir,
                  k1=20,
                  k2=6,
                  p=0.3,
                  use_rerank=False,
                  use_flip=False,
                  n_randperm=0,
                  bn_keys=[]):
    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats, pids, camids = [], [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data, pid, camid, _ = batch
            data = data.cuda()

            if use_flip:
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for i in range(2):
                    # flip
                    if i == 1:
                        data = data.index_select(
                            3,
                            torch.arange(data.size(3) - 1, -1,
                                         -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    # cat
                    if i == 0:
                        ff[:, :2048] = F.normalize(f, p=2, dim=1)
                    if i == 1:
                        ff[:, 2048:] = F.normalize(f, p=2, dim=1)
                ff = F.normalize(ff, p=2, dim=1)
                # ff = torch.FloatTensor(data.size(0), 2048).zero_()
                # for i in range(2):
                #     if i == 1:
                #         data = data.index_select(3, torch.arange(data.size(3) - 1, -1, -1).long().to('cuda'))
                #     outputs = model(data)
                #     f = outputs.data.cpu()
                #     ff = ff + f
                # fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                # ff = ff.div(fnorm.expand_as(ff))
            else:
                ff = model(data).data.cpu()
                ff = F.normalize(ff, p=2, dim=1)

            feats.append(ff)
            pids.append(pid)
            camids.append(camid)
    all_feature = torch.cat(feats, dim=0)
    # all_feature = all_feature[:,:1024+512]
    pids = torch.cat(pids, dim=0)
    camids = torch.cat(camids, dim=0)

    # DBA: database-augmentation style feature expansion (see the sketch after inference_samples above)
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()

        # [todo] heap sort
        distmat = euclidean_dist(all_feature, all_feature)
        # initial_rank = distmat.numpy().argsort(axis=1)
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()

        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] *
                                     weights,
                                     axis=0)
                pbar.update(1)
        # import pdb;pdb.set_trace()
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)
    # aQE: weighted query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            k2, alpha))
        st = time.time()

        # # [todo] remove normalization; normalize is used to make sure the most similar one is itself
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        # sims = torch.mm(all_feature, all_feature.t()).numpy()

        # # [todo] heap sort
        # # initial_rank = sims.argsort(axis=1)[:,::-1]
        # initial_rank = np.argpartition(-sims,range(1,k2+1))

        # all_feature = all_feature.numpy()

        # V_qe = np.zeros_like(all_feature,dtype=np.float32)

        # # [todo] update query feature only?
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         # get weights from similarity
        #         weights = sims[i,initial_rank[i,:k2]].reshape((-1,1))
        #         # weights = (weights-weights.min())/(weights.max()-weights.min())
        #         weights = np.power(weights,alpha)
        #         # import pdb;pdb.set_trace()

        #         V_qe[i,:] = np.mean(all_feature[initial_rank[i,:k2],:]*weights,axis=0)
        #         pbar.update(1)
        # # import pdb;pdb.set_trace()
        # all_feature = V_qe
        # del V_qe
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)

        # func = functools.partial(aqe_func,all_feature=all_feature,k2=k2,alpha=alpha)
        # all_features = mmcv.track_parallel_progress(func, all_feature, 6)

        # cpu
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        # all_feature = all_feature.numpy()

        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(aqe_func(all_feature[i],all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)
        # all_feature = np.stack(all_features,axis=0)
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)

        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)
        # import pdb;pdb.set_trace()

    if args.pseudo:
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))
        st = time.time()
        # cal sparse distmat
        all_feature = F.normalize(all_feature, p=2, dim=1)

        # all_distmat = euclidean_dist(all_feature, all_feature).numpy()
        # print(all_distmat[0])
        # pred1 = predict_pseudo_label(all_distmat,args.pseudo_eps,args.pseudo_minpoints,args.pseudo_maxpoints,args.pseudo_algorithm)
        # print(list(pred1.keys())[:10])

        if args.pseudo_visual:
            all_distmat, kdist = get_sparse_distmat(
                all_feature,
                eps=args.pseudo_eps + 0.1,
                len_slice=2000,
                use_gpu=True,
                dist_k=args.pseudo_minpoints)
            plt.plot(list(range(len(kdist))), np.sort(kdist), linewidth=0.5)
            plt.savefig('eval_kdist.png')
            plt.savefig(save_dir + 'eval_kdist.png')
        else:
            all_distmat = get_sparse_distmat(all_feature,
                                             eps=args.pseudo_eps + 0.1,
                                             len_slice=2000,
                                             use_gpu=True)

        # print(all_distmat.todense()[0])

        pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                            args.pseudo_minpoints,
                                            args.pseudo_maxpoints,
                                            args.pseudo_algorithm)
        print("pseudo cost: {}s".format(time.time() - st))
        print("pseudo id cnt:", len(pseudolabels))
        print("pseudo img cnt:",
              len([x for k, v in pseudolabels.items() for x in v]))
        print("pseudo cost: {}s".format(time.time() - st))
        # print(list(pred.keys())[:10])
    print('feature shape:', all_feature.size())

    #
    # for k1 in range(5,10,2):
    #     for k2 in range(2,5,1):
    #         for l in range(5,8):
    #             p = l*0.1

    if n_randperm <= 0:
        k2 = args.k2
        gallery_feat = all_feature[num_query:]
        query_feat = all_feature[:num_query]

        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        if use_rerank:
            print('==> using rerank')
            # distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
            distmat = re_ranking_batch_gpu(
                torch.cat([query_feat, gallery_feat], dim=0), num_query,
                args.k1, args.k2, p)
        else:
            print('==> using euclidean_dist')
            distmat = euclidean_dist(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(distmat, query_pid.numpy(),
                                gallery_pid.numpy(), query_camid.numpy(),
                                gallery_camid.numpy())
    else:
        k2 = args.k2
        torch.manual_seed(0)
        cmc = 0
        mAP = 0
        for i in range(n_randperm):
            index = torch.randperm(all_feature.size()[0])

            query_feat = all_feature[index][:num_query]
            gallery_feat = all_feature[index][num_query:]

            query_pid = pids[index][:num_query]
            query_camid = camids[index][:num_query]

            gallery_pid = pids[index][num_query:]
            gallery_camid = camids[index][num_query:]

            if use_rerank:
                print('==> using rerank')
                st = time.time()
                # distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
                distmat = re_ranking_batch_gpu(
                    torch.cat([query_feat, gallery_feat], dim=0), num_query,
                    args.k1, args.k2, p)

                print("re_rank cost:", time.time() - st)

            else:
                print('==> using euclidean_dist')
                st = time.time()
                distmat = euclidean_dist(query_feat, gallery_feat)
                print("euclidean_dist cost:", time.time() - st)

            _cmc, _mAP, _ = eval_func(distmat, query_pid.numpy(),
                                      gallery_pid.numpy(), query_camid.numpy(),
                                      gallery_camid.numpy())
            cmc += _cmc / n_randperm
            mAP += _mAP / n_randperm

    print('Validation Result:')
    if use_rerank:
        print(str(k1) + "  -  " + str(k2) + "  -  " + str(p))
    print('mAP: {:.2%}'.format(mAP))
    for r in [1, 5, 10]:
        print('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    print('average of mAP and rank1: {:.2%}'.format((mAP + cmc[0]) / 2.0))

    with open(save_dir + 'eval.txt', 'a') as f:
        if use_rerank:
            f.write('==> using rerank\n')
            f.write(str(k1) + "  -  " + str(k2) + "  -  " + str(p) + "\n")
        else:
            f.write('==> using euclidean_dist\n')

        f.write('mAP: {:.2%}'.format(mAP) + "\n")
        for r in [1, 5, 10]:
            f.write('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]) + "\n")
        f.write('average of mAP and rank1: {:.2%}\n'.format(
            (mAP + cmc[0]) / 2.0))

        f.write('------------------------------------------\n')
        f.write('------------------------------------------\n')
        f.write('\n\n')
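
# A minimal sketch of the weighted query expansion (aQE) that aqe_func_gpu is assumed
# to implement, following the commented-out CPU version inside inference_val above:
# each feature becomes the mean of its k2 most similar neighbours' features, weighted
# by cosine similarity raised to the power alpha, and is then L2-normalized.
import torch
import torch.nn.functional as F


def aqe_expand(features, k2=5, alpha=3.0):
    feats = F.normalize(features, p=2, dim=1)
    sims = feats @ feats.t()                                # (N, N) cosine similarities
    topk_sims, topk_idx = sims.topk(k2, dim=1)              # k2 most similar rows (incl. self)
    weights = topk_sims.clamp(min=0).pow(alpha).unsqueeze(-1)
    expanded = (feats[topk_idx] * weights).mean(dim=1)      # similarity-weighted mean
    return F.normalize(expanded, p=2, dim=1)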