예제 #1
0
    def get_result(self, show):
        """Run inference on the test set and write retrieval results to JSON.

        Extracts TTA features for every test image, splits them into query
        and gallery parts, computes a distance matrix with the configured
        metric, and stores the top ``num_choose`` gallery names per query
        in ``./result.json``.

        :param show: whether to display the retrieved images for demo queries
        :return: None
        """
        tbar = tqdm.tqdm(self.test_dataloader)
        features_all, names_all = [], []
        with torch.no_grad():
            for i, (images, names) in enumerate(tbar):
                # Forward pass with test-time augmentation; dummy zero labels
                # are supplied because the solver expects (images, labels).
                features = self.solver.tta(
                    (images, torch.zeros((images.size(0)))))

                features_all.append(features.detach().cpu())
                names_all.extend(names)

        features_all = torch.cat(features_all, dim=0)
        query_features = features_all[:self.num_query]
        gallery_features = features_all[self.num_query:]

        query_names = np.array(names_all[:self.num_query])
        gallery_names = np.array(names_all[self.num_query:])

        if self.dist == 're_rank':
            distmat = re_rank(query_features, gallery_features)
        elif self.dist == 'cos_dist':
            distmat = cos_dist(query_features, gallery_features)
        elif self.dist == 'euclidean_dist':
            distmat = euclidean_dist(query_features, gallery_features)
        else:
            # BUG FIX: the original `assert "Not implemented..."` was a no-op
            # (a non-empty string is always truthy); raise explicitly so an
            # unknown metric fails loudly instead of using an unbound distmat.
            raise NotImplementedError(
                "Not implemented :{}".format(self.dist))

        result = {}
        for query_index, query_dist in enumerate(distmat):
            # Smaller distance == better match; keep the closest entries.
            choose_index = np.argsort(query_dist)[:self.num_choose]
            query_name = query_names[query_index]
            gallery_name = gallery_names[choose_index]
            result[query_name] = gallery_name.tolist()
            if query_name in self.demo_names:
                self.show_result(query_name, gallery_name, 5, show)

        with codecs.open('./result.json', 'w', "utf-8") as json_file:
            json.dump(result, json_file, ensure_ascii=False)
예제 #2
0
    def validation(self, valid_loader):
        """Run the validation pass and report retrieval metrics.

        :param valid_loader: DataLoader over the validation set
        :return rank1: rank-1 precision; float
        :return mAP: mean average precision; float
        :return average_score: 0.5 * rank1 + 0.5 * mAP; float
        """
        self.model.eval()
        tbar = tqdm.tqdm(valid_loader)
        features_all, labels_all = [], []
        with torch.no_grad():
            for i, (images, labels, paths) in enumerate(tbar):
                # Forward pass with test-time augmentation.
                features = self.solver.tta((images, labels))
                features_all.append(features.detach().cpu())
                labels_all.append(labels)

        features_all = torch.cat(features_all, dim=0)
        labels_all = torch.cat(labels_all, dim=0)

        query_features = features_all[:self.num_query]
        query_labels = labels_all[:self.num_query]

        gallery_features = features_all[self.num_query:]
        gallery_labels = labels_all[self.num_query:]

        if self.dist == 're_rank':
            distmat = re_rank(query_features, gallery_features)
        elif self.dist == 'cos_dist':
            distmat = cos_dist(query_features, gallery_features)
        elif self.dist == 'euclidean_dist':
            distmat = euclidean_dist(query_features, gallery_features)
        else:
            # BUG FIX: `assert "..."` was always truthy and never fired;
            # raise so an unknown metric fails loudly instead of falling
            # through with distmat unbound.
            raise NotImplementedError(
                "Not implemented :{}".format(self.dist))

        all_rank_precison, mAP, _ = eval_func(distmat, query_labels.numpy(), gallery_labels.numpy(),
                                              use_cython=self.cython)

        rank1 = all_rank_precison[0]
        average_score = 0.5 * rank1 + 0.5 * mAP
        print('Rank1: {:.2%}, mAP {:.2%}, average score {:.2%}'.format(rank1, mAP, average_score))
        return rank1, mAP, average_score
예제 #3
0
    def get_result(self, show):
        """Run inference on the validation set and visualize retrievals.

        Extracts TTA features for every validation image, splits them into
        query and gallery parts, computes a distance matrix with the
        configured metric, and shows the top ``num_choose`` matches for
        each query.

        :param show: whether to display the retrieved images
        :return: None
        """
        tbar = tqdm.tqdm(self.valid_dataloader)
        features_all, labels_all, paths_all = [], [], []
        with torch.no_grad():
            for i, (images, labels, paths) in enumerate(tbar):
                # Forward pass with test-time augmentation.
                features = self.solver.tta(images)

                features_all.append(features.detach().cpu())
                labels_all.extend(labels)
                paths_all.extend(paths)

        features_all = torch.cat(features_all, dim=0)
        query_features = features_all[:self.num_query]
        gallery_features = features_all[self.num_query:]

        # (typo fix: was `query_lables`)
        query_labels = np.array(labels_all[:self.num_query])
        gallery_labels = np.array(labels_all[self.num_query:])

        query_paths = np.array(paths_all[:self.num_query])
        gallery_paths = np.array(paths_all[self.num_query:])

        if self.dist == 're_rank':
            distmat = re_rank(query_features, gallery_features)
        elif self.dist == 'cos_dist':
            distmat = cos_dist(query_features, gallery_features)
        elif self.dist == 'euclidean_dist':
            distmat = euclidean_dist(query_features, gallery_features)
        else:
            # BUG FIX: `assert "..."` was always truthy and never fired;
            # raise so an unknown metric fails loudly.
            raise NotImplementedError(
                "Not implemented :{}".format(self.dist))

        for query_index, query_dist in enumerate(distmat):
            # Smaller distance == better match; keep the closest entries.
            choose_index = np.argsort(query_dist)[:self.num_choose]
            query_path = query_paths[query_index]
            gallery_path = gallery_paths[choose_index]
            query_label = query_labels[query_index]
            gallery_label = gallery_labels[choose_index]
            self.show_result(query_path, gallery_path, query_label, gallery_label, 5, show)
예제 #4
0
def inference_samples(model, transform, batch_size):  # model, preprocessing transform, batch size
    """Extract features for the A-board test set and dump top-200 results.

    Reads the query list from the test-set txt file, extracts features for
    all query and gallery images in mini-batches, re-ranks the query-gallery
    distance matrix, and writes the 200 closest gallery filenames per query
    to ``submission_A.json``.

    :param model: feature-extraction network (moved to the global ``device``)
    :param transform: preprocessing applied to every image
    :param batch_size: mini-batch size used during inference
    :return: None
    """
    query_list = list()
    with open(r'初赛A榜测试集/query_a_list.txt', 'r') as f:
        # txt file listing the query images of the test set
        lines = f.readlines()
        for line in lines:
            data = line.split(" ")
            image_name = data[0].split("/")[1]
            img_file = os.path.join(r'初赛A榜测试集\query_a',
                                    image_name)  # test-set query folder
            query_list.append(img_file)

    gallery_list = [
        os.path.join(r'初赛A榜测试集\gallery_a', x) for x in  # test-set gallery folder
        os.listdir(r'初赛A榜测试集\gallery_a')
    ]
    query_num = len(query_list)
    img_list = list()
    for q_img in query_list:
        q_img = read_image(q_img)
        q_img = transform(q_img)
        img_list.append(q_img)
    for g_img in gallery_list:
        g_img = read_image(g_img)
        g_img = transform(g_img)
        img_list.append(g_img)
    img_data = torch.Tensor([t.numpy() for t in img_list])
    model = model.to(device)
    model.eval()
    # ceil(len(img_list) / batch_size) iterations
    iter_n = len(img_list) // batch_size
    if len(img_list) % batch_size != 0:
        iter_n += 1
    all_feature = list()
    for i in range(iter_n):
        print("batch ----%d----" % (i))
        batch_data = img_data[i * batch_size:(i + 1) * batch_size]
        with torch.no_grad():
            # BUG FIX: move the batch to the model's device (no-op on CPU);
            # previously the batch stayed on CPU even when `device` was a GPU.
            batch_feature = model(batch_data.to(device)).detach().cpu()
            all_feature.append(batch_feature)
    all_feature = torch.cat(all_feature)
    gallery_feat = all_feature[query_num:]
    query_feat = all_feature[:query_num]

    distmat = re_rank(query_feat, gallery_feat)  # re-ranked distance matrix
    # NOTE: when using euclidean_dist without re-ranking, convert with
    # `distmat = distmat.numpy()` instead (the old no-op `distmat = distmat`
    # was removed).
    num_q, num_g = distmat.shape
    indices = np.argsort(distmat, axis=1)
    max_200_indices = indices[:, :200]

    res_dict = dict()
    for q_idx in range(num_q):
        print(query_list[q_idx])
        filename = query_list[q_idx][query_list[q_idx].rindex("\\") + 1:]
        max_200_files = [
            gallery_list[i][gallery_list[i].rindex("\\") + 1:]
            for i in max_200_indices[q_idx]
        ]
        res_dict[filename] = max_200_files

    with open(r'submission_A.json', 'w', encoding='utf-8') as f:  # submission file
        json.dump(res_dict, f)
예제 #5
0
def test(args):
    """Evaluate a trained re-ID model described by the merged config.

    Merges CLI overrides into ``cfg``, builds the validation dataloader and
    model, restores weights from ``cfg.TEST.WEIGHT``, extracts features,
    and logs CMC / mAP. Optionally visualizes the worst queries
    (``cfg.TEST.VIS``) and repeats the evaluation with re-ranking
    (``cfg.TEST.RERANK``).

    :param args: parsed CLI arguments providing ``config_file`` and ``opts``
    """
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    logger = setup_logger('reid_baseline.eval', cfg.OUTPUT_DIR, 0, train=False)

    logger.info('Running with config:\n{}'.format(cfg))

    _, val_dl, num_query, num_classes = make_dataloader(cfg)

    model = build_model(cfg, num_classes)
    if cfg.TEST.MULTI_GPU:
        model = nn.DataParallel(model)
        model = convert_model(model)
        logger.info('Use multi gpu to inference')
    para_dict = torch.load(cfg.TEST.WEIGHT)
    model.load_state_dict(para_dict)
    model.cuda()
    model.eval()

    # Collect features and metadata for the whole validation set.
    feats, pids, camids, paths = [], [], [], []
    with torch.no_grad():
        for batch in tqdm(val_dl, total=len(val_dl), leave=False):
            data, pid, camid, path = batch
            paths.extend(list(path))
            data = data.cuda()
            feat = model(data).detach().cpu()
            feats.append(feat)
            pids.append(pid)
            camids.append(camid)
    feats = torch.cat(feats, dim=0)
    pids = torch.cat(pids, dim=0)
    camids = torch.cat(camids, dim=0)

    # The first num_query entries of the loader are queries, the rest gallery.
    query_feat = feats[:num_query]
    query_pid = pids[:num_query]
    query_camid = camids[:num_query]
    query_path = np.array(paths[:num_query])

    gallery_feat = feats[num_query:]
    gallery_pid = pids[num_query:]
    gallery_camid = camids[num_query:]
    gallery_path = np.array(paths[num_query:])

    distmat = euclidean_dist(query_feat, gallery_feat)

    cmc, mAP, all_AP = eval_func(distmat.numpy(),
                                 query_pid.numpy(),
                                 gallery_pid.numpy(),
                                 query_camid.numpy(),
                                 gallery_camid.numpy(),
                                 use_cython=True)

    if cfg.TEST.VIS:
        # Visualize the queries with the lowest AP (worst retrievals).
        worst_q = np.argsort(all_AP)[:cfg.TEST.VIS_Q_NUM]
        qid = query_pid[worst_q]
        q_im = query_path[worst_q]

        ind = np.argsort(distmat, axis=1)
        gid = gallery_pid[ind[worst_q]][..., :cfg.TEST.VIS_G_NUM]
        g_im = gallery_path[ind[worst_q]][..., :cfg.TEST.VIS_G_NUM]

        for idx in range(cfg.TEST.VIS_Q_NUM):
            # Mark gallery entries sharing the query's person id.
            sid = qid[idx] == gid[idx]
            im = rank_list_to_im(range(len(g_im[idx])), sid, q_im[idx],
                                 g_im[idx])

            im.save(
                osp.join(cfg.OUTPUT_DIR,
                         'worst_query_{}.jpg'.format(str(idx).zfill(2))))

    logger.info('Validation Result:')
    for r in cfg.TEST.CMC:
        logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    logger.info('mAP: {:.2%}'.format(mAP))
    logger.info('-' * 20)

    if not cfg.TEST.RERANK:
        return

    # Second pass: same evaluation with a re-ranked distance matrix.
    distmat = re_rank(query_feat, gallery_feat)
    cmc, mAP, all_AP = eval_func(distmat,
                                 query_pid.numpy(),
                                 gallery_pid.numpy(),
                                 query_camid.numpy(),
                                 gallery_camid.numpy(),
                                 use_cython=True)

    logger.info('ReRanking Result:')
    for r in cfg.TEST.CMC:
        logger.info('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    logger.info('mAP: {:.2%}'.format(mAP))
    logger.info('-' * 20)
예제 #6
0
def inference_val(model,  transform, batch_size, feature_dim, k1=20, k2=6, p=0.3, use_rerank=False):
    """Evaluate on a local validation split with flip-averaged features.

    Parses person/camera ids from filenames of the form ``<pid>_c<camid>.png``,
    extracts L2-normalized flip-averaged features, computes the (optionally
    re-ranked) query-gallery distance matrix and prints / appends CMC and mAP
    to ``re_rank.txt``.

    :param model: feature-extraction network (moved to the global ``device``)
    :param transform: preprocessing applied to every image
    :param batch_size: mini-batch size used during inference
    :param feature_dim: dimensionality of the model's output feature vector
    :param k1: re-ranking parameter k1
    :param k2: re-ranking parameter k2
    :param p: re-ranking lambda
    :param use_rerank: whether to apply re-ranking to the distance matrix
    :return: None
    """
    q_img_list = os.listdir(r'E:\data\reid\dataset7\query')
    query_list = list()
    qid_list = list()
    qcid_list = list()
    for q_img in q_img_list:
        query_list.append(os.path.join(r'E:\data\reid\dataset7\query', q_img))
        # NOTE(review): str.strip(".png") strips a character *set*, not the
        # suffix; safe only while ids are purely numeric — confirm filenames.
        qid_list.append(int(q_img.strip(".png").split("_")[0]))
        qcid_list.append(int(q_img.strip(".png").split("_")[1].strip("c")))

    g_img_list = os.listdir(r'E:\data\reid\dataset7\gallery')
    gallery_list = list()
    gid_list = list()
    gcid_list = list()
    for g_img in g_img_list:
        gallery_list.append(os.path.join(r'E:\data\reid\dataset7\gallery', g_img))
        gid_list.append(int(g_img.strip(".png").split("_")[0]))
        gcid_list.append(int(g_img.strip(".png").split("_")[1].strip("c")))
    img_list = list()
    for q_img in query_list:
        q_img = read_image(q_img)
        q_img = transform(q_img)
        img_list.append(q_img)
    for g_img in gallery_list:
        g_img = read_image(g_img)
        g_img = transform(g_img)
        img_list.append(g_img)
    query_num = len(query_list)
    img_data = torch.Tensor([t.numpy() for t in img_list])

    model = model.to(device)
    model.eval()
    # ceil(len(img_list) / batch_size) iterations
    iter_n = len(img_list) // batch_size
    if len(img_list) % batch_size != 0:
        iter_n += 1
    all_feature = list()
    for i in range(iter_n):
        batch_data = img_data[i * batch_size:(i + 1) * batch_size]
        with torch.no_grad():
            # BUG FIX: use the `feature_dim` parameter (it was previously
            # ignored and the buffer hard-coded to 2048, inconsistent with
            # the sibling inference_samples implementation).
            ff = torch.FloatTensor(batch_data.size(0), feature_dim).zero_()
            for flip in range(2):  # renamed from `i`, which shadowed the batch index
                if flip == 1:
                    # horizontal flip along the width dimension
                    batch_data = batch_data.index_select(3, torch.arange(batch_data.size(3) - 1, -1, -1).long())
                # BUG FIX: move the batch to the model's device (no-op on CPU)
                outputs = model(batch_data.to(device))
                f = outputs.data.cpu()
                ff = ff + f

            # L2-normalize the summed (original + flipped) descriptor
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))

            all_feature.append(ff)
    all_feature = torch.cat(all_feature)
    gallery_feat = all_feature[query_num:]
    query_feat = all_feature[:query_num]
    if use_rerank:
        distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
    else:
        distmat = euclidean_dist(query_feat, gallery_feat)

    cmc, mAP, _ = eval_func(distmat, np.array(qid_list), np.array(gid_list),
              np.array(qcid_list), np.array(gcid_list))
    print('Validation Result:')
    print(str(k1) + "  -  " + str(k2) + "  -  " + str(p))
    for r in [1, 5, 10]:
        print('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    print('mAP: {:.2%}'.format(mAP))
    with open('re_rank.txt', 'a') as f:
        f.write(str(k1)+"  -  "+str(k2)+"  -  "+str(p) + "\n")
        for r in [1, 5, 10]:
            f.write('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1])+"\n")
        f.write('mAP: {:.2%}'.format(mAP) + "\n")
        f.write('------------------------------------------\n')
        f.write('------------------------------------------\n')
        f.write('\n\n')
예제 #7
0
def pseudo_label_samples(model, query_list, gallery_list,  transform, batch_size, k1=20, k2=6, p=0.3):
    """Mine pseudo-label pairs by ranking gallery images against queries.

    Extracts features for all query and gallery images, re-ranks the
    query-gallery distance matrix, and writes three CSVs:
    ``pseudo_res.csv`` (all top-200 pairs), ``error_pseudo_res.csv``
    (top-1 matches whose id prefix disagrees with the query's) and
    ``true_pseudo_res.csv`` (top-200 pairs kept when the best distance is
    below 0.1).

    :param model: feature-extraction network (moved to the global ``device``)
    :param query_list: paths of the query images
    :param gallery_list: paths of the gallery images
    :param transform: preprocessing applied to every image
    :param batch_size: mini-batch size used during inference
    :param k1: re-ranking parameter k1
    :param k2: re-ranking parameter k2
    :param p: re-ranking lambda
    :return: None
    """
    query_num = len(query_list)
    img_list = list()
    for q_img in query_list:
        q_img = read_image(q_img)
        q_img = transform(q_img)
        img_list.append(q_img)
    for g_img in gallery_list:
        g_img = read_image(g_img)
        g_img = transform(g_img)
        img_list.append(g_img)
    img_data = torch.Tensor([t.numpy() for t in img_list])
    model = model.to(device)
    model.eval()
    # ceil(len(img_list) / batch_size) iterations
    iter_n = len(img_list) // batch_size
    if len(img_list) % batch_size != 0:
        iter_n += 1
    all_feature = list()
    for i in range(iter_n):
        print("batch ----%d----" % (i))
        batch_data = img_data[i*batch_size:(i+1)*batch_size]
        with torch.no_grad():
            # BUG FIX: move the batch to the model's device (no-op on CPU).
            batch_feature = model(batch_data.to(device)).detach().cpu()
            all_feature.append(batch_feature)
    all_feature = torch.cat(all_feature)
    gallery_feat = all_feature[query_num:]
    query_feat = all_feature[:query_num]

    distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
    num_q, num_g = distmat.shape
    indices = np.argsort(distmat, axis=1)
    max_200_indices = indices[:, :200]

    res_dict = dict()
    pseudo_res = {"q_imgs": list(), "g_imgs": list(), "probs": list()}
    error_prob = {"q_imgs": list(), "g_imgs": list(), "probs": list()}
    true_prob = {"q_imgs": list(), "g_imgs": list(), "probs": list()}
    for q_idx in range(num_q):
        print(query_list[q_idx])
        filename = query_list[q_idx][query_list[q_idx].rindex("\\")+1:]
        max_200_files = [gallery_list[i][gallery_list[i].rindex("\\")+1:] for i in max_200_indices[q_idx]]
        probs = [distmat[q_idx, i] for i in max_200_indices[q_idx]]

        # Record a wrong top-1: the id prefix of the best gallery match
        # differs from the query's id prefix.
        if max_200_files[0].split("_")[0] != filename.split("_")[0]:
            error_prob["q_imgs"].append(filename)
            error_prob["g_imgs"].append(max_200_files[0])
            error_prob["probs"].append(probs[0])

        # Hoisted out of the original per-entry loop: the condition was
        # loop-invariant (it always tested probs[0]).
        # NOTE(review): `probs[i] < 0.1` may have been intended; the original
        # behavior (keep all 200 when the best distance is small) is kept.
        if probs[0] < 0.1:
            for g_file, g_prob in zip(max_200_files, probs):
                true_prob["q_imgs"].append(filename)
                true_prob["g_imgs"].append(g_file)
                true_prob["probs"].append(g_prob)

        for g_filename, prob in zip(max_200_files, probs):
            pseudo_res["q_imgs"].append(filename)
            pseudo_res["g_imgs"].append(g_filename)
            pseudo_res["probs"].append(prob)

        res_dict[filename] = max_200_files

    columns = [u'q_imgs', u'g_imgs', u'probs']
    pd.DataFrame(pseudo_res, columns=columns).to_csv('pseudo_res.csv')
    pd.DataFrame(error_prob, columns=columns).to_csv('error_pseudo_res.csv')
    pd.DataFrame(true_prob, columns=columns).to_csv('true_pseudo_res.csv')
예제 #8
0
def inference_samples(model,  transform, batch_size, feature_dim, k1=20, k2=6, p=0.3, use_rerank=False):
    """Run flip-averaged inference on the A-board test set and save top-200.

    Extracts L2-normalized, flip-averaged features for all query and gallery
    images, computes the (optionally re-ranked) distance matrix and writes
    the 200 closest gallery filenames per query to ``submission_A.json``.

    :param model: feature-extraction network (moved to the global ``device``)
    :param transform: preprocessing applied to every image
    :param batch_size: mini-batch size used during inference
    :param feature_dim: dimensionality of the model's output feature vector
    :param k1: re-ranking parameter k1
    :param k2: re-ranking parameter k2
    :param p: re-ranking lambda
    :param use_rerank: whether to apply re-ranking to the distance matrix
    :return: None
    """
    query_list = list()
    with open(r'E:\data\reid\初赛A榜测试集\初赛A榜测试集/query_a_list.txt', 'r') as f:
        lines = f.readlines()
        for line in lines:
            data = line.split(" ")
            image_name = data[0].split("/")[1]
            img_file = os.path.join(r'E:\data\reid\初赛A榜测试集\初赛A榜测试集\query_a', image_name)
            query_list.append(img_file)

    gallery_list = [os.path.join(r'E:\data\reid\初赛A榜测试集\初赛A榜测试集\gallery_a', x) for x in
                    os.listdir(r'E:\data\reid\初赛A榜测试集\初赛A榜测试集\gallery_a')]
    query_num = len(query_list)
    img_list = list()
    for q_img in query_list:
        q_img = read_image(q_img)
        q_img = transform(q_img)
        img_list.append(q_img)
    for g_img in gallery_list:
        g_img = read_image(g_img)
        g_img = transform(g_img)
        img_list.append(g_img)
    img_data = torch.Tensor([t.numpy() for t in img_list])
    model = model.to(device)
    model.eval()
    # ceil(len(img_list) / batch_size) iterations
    iter_n = len(img_list) // batch_size
    if len(img_list) % batch_size != 0:
        iter_n += 1
    all_feature = list()
    for i in range(iter_n):
        print("batch ----%d----" % (i))
        batch_data = img_data[i*batch_size:(i+1)*batch_size]
        with torch.no_grad():
            # Accumulate features of the original and horizontally flipped batch.
            ff = torch.FloatTensor(batch_data.size(0), feature_dim).zero_()
            for flip in range(2):  # renamed from `i`, which shadowed the batch index
                if flip == 1:
                    # horizontal flip along the width dimension
                    batch_data = batch_data.index_select(3, torch.arange(batch_data.size(3) - 1, -1, -1).long())
                # BUG FIX: move the batch to the model's device (no-op on CPU)
                outputs = model(batch_data.to(device))
                f = outputs.data.cpu()
                ff = ff + f

            # L2-normalize the summed descriptor
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))
            all_feature.append(ff)
    all_feature = torch.cat(all_feature)
    gallery_feat = all_feature[query_num:]
    query_feat = all_feature[:query_num]

    if use_rerank:
        print("use re_rank")
        distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
    else:
        distmat = euclidean_dist(query_feat, gallery_feat)
        distmat = distmat.numpy()
    num_q, num_g = distmat.shape
    indices = np.argsort(distmat, axis=1)

    max_200_indices = indices[:, :200]

    res_dict = dict()
    for q_idx in range(num_q):
        print(query_list[q_idx])
        filename = query_list[q_idx][query_list[q_idx].rindex("\\")+1:]
        max_200_files = [gallery_list[i][gallery_list[i].rindex("\\")+1:] for i in max_200_indices[q_idx]]
        res_dict[filename] = max_200_files

    with open(r'submission_A.json', 'w' ,encoding='utf-8') as f:
        json.dump(res_dict, f)
예제 #9
0
def inference_samples(args,
                      model,
                      transform,
                      batch_size,
                      query_txt,
                      query_dir,
                      gallery_dir,
                      save_dir,
                      k1=20,
                      k2=6,
                      p=0.3,
                      use_rerank=False,
                      use_flip=False,
                      max_rank=200,
                      bn_keys=[]):
    # NOTE(review): `bn_keys=[]` is a mutable default argument; safe only
    # because it is never mutated here — consider `bn_keys=None`.
    """Full test-time inference pipeline: features -> post-processing -> ranking -> JSON.

    Builds a sequential dataloader over query + gallery images, optionally
    recalibrates BN statistics (AdaBN), extracts (optionally flip-concatenated)
    features, applies optional DBA / weighted query expansion / pseudo-label
    clustering, computes the distance matrix (re-rank or weighted MGN
    euclidean), and saves the per-query top ``max_rank`` gallery filenames as
    JSON plus the raw distance matrix as a pickle.

    :param args: namespace of post-processing flags (adabn, dba, aqe, pseudo,
        post, save_fname, ...) read throughout the pipeline
    :param model: feature-extraction network (moved to the global ``device``)
    :param transform: preprocessing applied to every image
    :param batch_size: mini-batch size for both dataloaders
    :param query_txt: optional txt listing query images; "" lists ``query_dir``
    :param query_dir: directory of query images
    :param gallery_dir: directory of gallery images
    :param save_dir: directory/prefix for the output JSON and pickle
    :param k1: re-ranking parameter k1
    :param k2: re-ranking parameter k2 (overridden by DBA/aQE branches)
    :param p: re-ranking lambda
    :param use_rerank: whether to re-rank the distance matrix
    :param use_flip: whether to concatenate original and flipped features
    :param max_rank: number of gallery entries kept per query
    :param bn_keys: BN layer name fragments for selective AdaBN
    """
    print("==>load data info..")
    if query_txt != "":
        query_list = list()
        with open(query_txt, 'r') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                data = line.split(" ")
                image_name = data[0].split("/")[1]
                img_file = os.path.join(query_dir, image_name)
                query_list.append(img_file)
    else:
        query_list = [
            os.path.join(query_dir, x) for x in os.listdir(query_dir)
        ]
    gallery_list = [
        os.path.join(gallery_dir, x) for x in os.listdir(gallery_dir)
    ]
    query_num = len(query_list)
    if args.save_fname != '':
        # Deterministic ordering when a named submission is requested.
        print(query_list[:10])
        query_list = sorted(query_list)
        print(query_list[:10])
        gallery_list = sorted(gallery_list)
    print("==>build dataloader..")
    image_set = ImageDataset(query_list + gallery_list, transform)
    dataloader = DataLoader(image_set,
                            sampler=SequentialSampler(image_set),
                            batch_size=batch_size,
                            num_workers=4)
    # Shuffled loader used only for AdaBN statistics recalibration.
    bn_dataloader = DataLoader(image_set,
                               sampler=RandomSampler(image_set),
                               batch_size=batch_size,
                               num_workers=4,
                               drop_last=True)

    print("==>model inference..")

    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           bn_dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, bn_dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats = []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data = batch
            data = data.cuda()

            if use_flip:
                # Concatenate original (first 2048 dims) and flipped
                # (last 2048 dims) features.
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for i in range(2):
                    # flip
                    if i == 1:
                        data = data.index_select(
                            3,
                            torch.arange(data.size(3) - 1, -1,
                                         -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    # cat
                    if i == 0:
                        ff[:, :2048] = f
                    if i == 1:
                        ff[:, 2048:] = f
            else:
                ff = model(data).data.cpu()
            feats.append(ff)

    all_feature = torch.cat(feats, dim=0)

    # DBA
    if args.dba:
        # Database augmentation: replace each feature with a weighted mean
        # of its k2 nearest neighbours, then re-normalize.
        k2 = args.dba_k2
        alpha = args.dba_alpha
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()

        # [todo] heap sort
        distmat = euclidean_dist(all_feature, all_feature)
        # initial_rank = distmat.numpy().argsort(axis=1)
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()

        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        # logspace weights decay with neighbour rank (alpha < 0).
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] *
                                     weights,
                                     axis=0)
                pbar.update(1)
        # import pdb;pdb.set_trace()
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)
    # aQE: weight query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            k2, alpha))
        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)
        # all_feature = all_feature.numpy()

        # all_features = []
        # with tqdm(total=len(all_feature)) as pbar:
        #     for i in range(len(all_feature)):
        #         all_features.append(aqe_func(all_feature[i],all_feature=all_feature,k2=k2,alpha=alpha))
        #         pbar.update(1)
        # all_feature = np.stack(all_features,axis=0)
        # all_feature = torch.from_numpy(all_feature)
        # all_feature = F.normalize(all_feature, p=2, dim=1)
        # fast by gpu
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)
    print('feature shape:', all_feature.size())
    if args.pseudo:
        # DBSCAN-style pseudo-labelling of the combined query+gallery set;
        # clustered images are copied into per-id folders for re-training.
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))

        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)

        if args.pseudo_visual:
            # Also plot the sorted k-distance curve to help pick eps.
            all_distmat, kdist = get_sparse_distmat(
                all_feature,
                eps=args.pseudo_eps + 0.1,
                len_slice=2000,
                use_gpu=True,
                dist_k=args.pseudo_minpoints)
            plt.plot(list(range(len(kdist))), np.sort(kdist), linewidth=0.5)
            plt.savefig('test_kdist.png')
            plt.savefig(save_dir + 'test_kdist.png')
        else:
            all_distmat = get_sparse_distmat(all_feature,
                                             eps=args.pseudo_eps + 0.1,
                                             len_slice=2000,
                                             use_gpu=True)

        # print(all_distmat.todense()[0])

        pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                            args.pseudo_minpoints,
                                            args.pseudo_maxpoints,
                                            args.pseudo_algorithm)
        print("pseudo cost: {}s".format(time.time() - st))
        print("pseudo id cnt:", len(pseudolabels))
        print("pseudo img cnt:",
              len([x for k, v in pseudolabels.items() for x in v]))

        # # save
        all_list = query_list + gallery_list
        save_path = args.pseudo_savepath

        pid = args.pseudo_startid
        camid = 0
        for k, v in pseudolabels.items():
            os.makedirs(os.path.join(save_path, str(pid)), exist_ok=True)
            for _index in pseudolabels[k]:
                filename = all_list[_index].split("/")[-1]
                new_filename = str(pid) + "_c" + str(camid) + ".png"
                shutil.copy(all_list[_index],
                            os.path.join(save_path, str(pid), new_filename))
                camid += 1
            pid += 1

    gallery_feat = all_feature[query_num:]
    query_feat = all_feature[:query_num]

    if use_rerank:
        print("==>use re_rank")
        st = time.time()
        distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
        print("re_rank cost:", time.time() - st)
    else:
        st = time.time()
        # Per-branch weights for the MGN multi-part descriptor.
        weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
        print('==> using mgn_euclidean_dist')
        print('==> using weights:', weights)
        # distmat = euclidean_dist(query_feat, gallery_feat)
        distmat = None
        if use_flip:
            # Average the distances of the original and flipped halves.
            for i in range(2):
                dist = mgn_euclidean_dist(
                    query_feat[:, i * 2048:(i + 1) * 2048],
                    gallery_feat[:, i * 2048:(i + 1) * 2048],
                    norm=True,
                    weights=weights)
                if distmat is None:
                    distmat = dist / 2.0
                else:
                    distmat += dist / 2.0
        else:
            distmat = mgn_euclidean_dist(query_feat,
                                         gallery_feat,
                                         norm=True,
                                         weights=weights)

        print("euclidean_dist cost:", time.time() - st)

        distmat = distmat.numpy()

    num_q, num_g = distmat.shape
    print("==>saving..")
    if args.post:
        qfnames = [fname.split('/')[-1] for fname in query_list]
        gfnames = [fname.split('/')[-1] for fname in gallery_list]
        st = time.time()
        print("post json using top_per:", args.post_top_per)
        res_dict = get_post_json(distmat, qfnames, gfnames, args.post_top_per)
        print("post cost:", time.time() - st)
    else:
        # [todo] fast test
        print("==>sorting..")
        st = time.time()
        indices = np.argsort(distmat, axis=1)
        print("argsort cost:", time.time() - st)
        # print(indices[:2, :max_rank])
        # st = time.time()
        # indices = np.argpartition( distmat, range(1,max_rank+1))
        # print("argpartition cost:",time.time()-st)
        # print(indices[:2, :max_rank])

        max_200_indices = indices[:, :max_rank]
        res_dict = dict()
        for q_idx in range(num_q):
            filename = query_list[q_idx].split('/')[-1]
            max_200_files = [
                gallery_list[i].split('/')[-1] for i in max_200_indices[q_idx]
            ]
            res_dict[filename] = max_200_files
    # Build the output filename from the post-processing flags that were used.
    if args.dba:
        save_fname = 'mgnsub_dba.json'
    elif args.aqe:
        save_fname = 'mgnsub_aqe.json'
    else:
        save_fname = 'mgnsub.json'
    if use_rerank:
        save_fname = 'rerank_' + save_fname
    if args.adabn:
        if args.adabn_all:
            save_fname = 'adabnall_' + save_fname
        else:
            save_fname = 'adabn_' + save_fname
    if use_flip:
        save_fname = 'flip_' + save_fname
    if args.post:
        save_fname = 'post_' + save_fname
    save_fname = args.save_fname + save_fname
    print('savefname:', save_fname)
    with open(save_dir + save_fname, 'w', encoding='utf-8') as f:
        json.dump(res_dict, f)
    with open(save_dir + save_fname.replace('.json', '.pkl'), 'wb') as fid:
        pickle.dump(distmat, fid, -1)
예제 #10
0
def _compute_distmat(args, query_feat, gallery_feat, p, use_rerank, use_flip):
    """Compute the query-gallery distance matrix.

    Uses k-reciprocal re-ranking when ``use_rerank`` is set, otherwise a
    part-weighted MGN euclidean distance. When ``use_flip`` is set the
    features are assumed to be the concatenation [original | flipped]
    (2048 dims each) and the two halves' distances are averaged.

    :param args: namespace providing ``k1`` and ``k2`` for re-ranking
    :param query_feat: (num_query, D) feature tensor
    :param gallery_feat: (num_gallery, D) feature tensor
    :param p: lambda parameter of the re-ranking
    :param use_rerank: use k-reciprocal re-ranking instead of euclidean dist
    :param use_flip: average distances of the original/flipped feature halves
    :return: (num_query, num_gallery) distance matrix
    """
    if use_rerank:
        print('==> using rerank')
        st = time.time()
        distmat = re_rank(query_feat, gallery_feat, args.k1, args.k2, p)
        print("re_rank cost:", time.time() - st)
        return distmat

    # Per-part weights for the 8 MGN feature chunks (presumably 256-dim
    # each, per the earlier part-wise experiments -- TODO confirm).
    weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
    print('==> using mgn_euclidean_dist')
    print('==> using weights:', weights)
    st = time.time()
    distmat = None
    if use_flip:
        # Average the distances computed from the original half [:2048]
        # and the flipped half [2048:] of the concatenated features.
        for half in range(2):
            dist = mgn_euclidean_dist(
                query_feat[:, half * 2048:(half + 1) * 2048],
                gallery_feat[:, half * 2048:(half + 1) * 2048],
                norm=True,
                weights=weights)
            distmat = dist / 2.0 if distmat is None else distmat + dist / 2.0
    else:
        distmat = mgn_euclidean_dist(query_feat,
                                     gallery_feat,
                                     norm=True,
                                     weights=weights)
    print("euclidean_dist cost:", time.time() - st)
    return distmat


def inference_val(args,
                  model,
                  dataloader,
                  num_query,
                  save_dir,
                  k1=20,
                  k2=6,
                  p=0.3,
                  use_rerank=False,
                  use_flip=False,
                  n_randperm=0,
                  bn_keys=None):
    """Run validation inference for a re-ID model and report CMC / mAP.

    Pipeline: optional AdaBN statistics update, feature extraction over the
    whole validation set (optionally fusing horizontally-flipped features),
    optional DBA / aQE feature post-processing, query-gallery distance
    computation, and CMC/mAP evaluation. Results are printed and appended
    to ``save_dir + 'eval.txt'``.

    :param args: option namespace (adabn/dba/aqe flags and their parameters,
                 plus ``k1``/``k2`` used by re-ranking)
    :param model: network to evaluate; moved to the module-level ``device``
    :param dataloader: yields ``(data, pid, camid, _)`` batches
    :param num_query: number of leading samples forming the query set
    :param save_dir: directory prefix (with trailing separator) for eval.txt
    :param k1: re-rank k1 (logged; ``re_rank`` itself reads ``args.k1``)
    :param k2: re-rank k2 (overwritten by ``args.k2`` before evaluation)
    :param p: re-rank lambda
    :param use_rerank: use k-reciprocal re-ranking instead of euclidean dist
    :param use_flip: concatenate original + flipped features (2048 dims each)
    :param n_randperm: if > 0, average metrics over this many random
                       query/gallery permutations; else use the natural split
    :param bn_keys: BN layer names for partial AdaBN (``None`` means none)
    :return: ``None`` -- results are printed and appended to eval.txt
    """
    # Was `bn_keys=[]`: a mutable default argument is shared across calls,
    # so use None as the sentinel instead (behaviorally identical here).
    if bn_keys is None:
        bn_keys = []

    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats, pids, camids = [], [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            data, pid, camid, _ = batch
            data = data.cuda()

            if use_flip:
                # Concatenated feature: [:, :2048] from the original image,
                # [:, 2048:] from the horizontally flipped image.
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                for flip_i in range(2):
                    if flip_i == 1:
                        # Flip along the width axis (dim 3).
                        data = data.index_select(
                            3,
                            torch.arange(data.size(3) - 1, -1,
                                         -1).long().to('cuda'))
                    outputs = model(data)
                    f = outputs.data.cpu()
                    if flip_i == 0:
                        ff[:, :2048] = f
                    else:
                        ff[:, 2048:] = f
            else:
                ff = model(data).data.cpu()

            feats.append(ff)
            pids.append(pid)
            camids.append(camid)
    all_feature = torch.cat(feats, dim=0)
    pids = torch.cat(pids, dim=0)
    camids = torch.cat(camids, dim=0)

    # DBA (database augmentation): replace each feature by the weighted mean
    # of its k2 nearest neighbours, then L2-normalize.
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        # Negative alpha makes np.logspace produce decaying neighbour weights.
        assert alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()

        distmat = euclidean_dist(all_feature, all_feature)
        # argpartition only orders the first k2+1 positions -- cheaper than a
        # full argsort since only the top-k2 neighbours are consumed below.
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()

        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i in range(len(all_feature)):
                V_qe[i, :] = np.mean(all_feature[initial_rank[i, :k2], :] *
                                     weights,
                                     axis=0)
                pbar.update(1)
        all_feature = V_qe
        del V_qe
        all_feature = torch.from_numpy(all_feature)

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)

    # aQE: weighted (alpha) query expansion, GPU implementation.
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            k2, alpha))
        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)

    print('feature shape:', all_feature.size())

    if n_randperm <= 0:
        # Natural split: the first num_query samples are the queries.
        k2 = args.k2
        query_feat = all_feature[:num_query]
        gallery_feat = all_feature[num_query:]

        query_pid = pids[:num_query]
        query_camid = camids[:num_query]
        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        distmat = _compute_distmat(args, query_feat, gallery_feat, p,
                                   use_rerank, use_flip)
        cmc, mAP, _ = eval_func(distmat, query_pid.numpy(),
                                gallery_pid.numpy(), query_camid.numpy(),
                                gallery_camid.numpy())
    else:
        # Average the metrics over n_randperm random query/gallery splits.
        k2 = args.k2
        torch.manual_seed(0)  # reproducible permutations
        cmc = 0
        mAP = 0
        for _ in range(n_randperm):
            index = torch.randperm(all_feature.size()[0])

            query_feat = all_feature[index][:num_query]
            gallery_feat = all_feature[index][num_query:]

            query_pid = pids[index][:num_query]
            query_camid = camids[index][:num_query]
            gallery_pid = pids[index][num_query:]
            gallery_camid = camids[index][num_query:]

            distmat = _compute_distmat(args, query_feat, gallery_feat, p,
                                       use_rerank, use_flip)
            _cmc, _mAP, _ = eval_func(distmat, query_pid.numpy(),
                                      gallery_pid.numpy(), query_camid.numpy(),
                                      gallery_camid.numpy())
            cmc += _cmc / n_randperm
            mAP += _mAP / n_randperm

    print('Validation Result:')
    if use_rerank:
        print(str(k1) + "  -  " + str(k2) + "  -  " + str(p))
    print('mAP: {:.2%}'.format(mAP))
    for r in [1, 5, 10]:
        print('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    print('average of mAP and rank1: {:.2%}'.format((mAP + cmc[0]) / 2.0))

    with open(save_dir + 'eval.txt', 'a') as f:
        if use_rerank:
            f.write('==> using rerank\n')
            f.write(str(k1) + "  -  " + str(k2) + "  -  " + str(p) + "\n")
        else:
            f.write('==> using euclidean_dist\n')

        f.write('mAP: {:.2%}'.format(mAP) + "\n")
        for r in [1, 5, 10]:
            f.write('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]) + "\n")
        f.write('average of mAP and rank1: {:.2%}\n'.format(
            (mAP + cmc[0]) / 2.0))

        f.write('------------------------------------------\n')
        f.write('------------------------------------------\n')
        f.write('\n\n')