Beispiel #1
0
def run(config, num_checkpoint, epoch_end, output_filename):
    """Average several checkpoints into one model (SWA) and save it.

    The first checkpoint is loaded into ``model``. Every further checkpoint
    is loaded into a fresh model and blended in with
    ``swa.moving_average(model, model2, 1 / (i + 2))``, presumably an
    equal-weight running mean (confirm against ``swa``). Batch-norm
    statistics are then refreshed with ``swa.bn_update`` on the 'train'
    split, loaded with the 'dev' transform, before the result is saved.

    Args:
        config: experiment configuration forwarded to the project helpers.
        num_checkpoint: forwarded to ``get_checkpoints``; also embedded in
            the output name.
        epoch_end: forwarded to ``get_checkpoints``.
        output_filename: prefix of the saved checkpoint name. The final name
            is ``'{output_filename}.{num_checkpoint}.{last_epoch:03d}'``,
            where ``last_epoch`` comes from the last checkpoint loaded.
    """
    task = get_task(config)
    preprocess_opt = task.get_preprocess_opt()
    dataloader = get_dataloader(config, 'train',
                                get_transform(config, 'dev', **preprocess_opt))

    model = task.get_model()
    checkpoints = get_checkpoints(config, num_checkpoint, epoch_end)
    print('checkpoints:')
    print('\n'.join(checkpoints))

    # Capture the epoch of the first checkpoint as well. With a single
    # checkpoint the loop below never runs, and ``last_epoch`` would
    # otherwise be unbound when building ``output_name``.
    last_epoch, _ = utils.checkpoint.load_checkpoint(model, None,
                                                     checkpoints[0])
    for i, checkpoint in enumerate(checkpoints[1:]):
        model2 = get_task(config).get_model()
        last_epoch, _ = utils.checkpoint.load_checkpoint(
            model2, None, checkpoint)
        # The (i + 2)-th checkpoint overall is blended in with weight 1/(i + 2).
        swa.moving_average(model, model2, 1. / (i + 2))

    with torch.no_grad():
        swa.bn_update(dataloader, model)

    output_name = '{}.{}.{:03d}'.format(output_filename, num_checkpoint,
                                        last_epoch)
    print('save {}'.format(output_name))
    utils.checkpoint.save_checkpoint(
        config,
        model,
        None,
        0,
        0,
        name=output_name,
        weights_dict={'state_dict': model.state_dict()})
Beispiel #2
0
def run(config, num_checkpoint, epoch_end, output_filename):
    """Fuse several checkpoints into a single SWA model and save it.

    The first checkpoint seeds the model. The k-th checkpoint overall
    (k >= 2) is blended in through ``swa.moving_average`` with weight 1/k.
    Batch-norm statistics are then refreshed with ``swa.bn_update`` on the
    'val' split (no transform). The weights are saved under the name
    ``output_filename``, with ``epoch_end`` passed as the fifth positional
    argument of ``save_checkpoint`` (presumably the epoch).
    """
    # This loader is used only for the batch-norm refresh below.
    val_loader = get_dataloader(config, split='val', transform=None)

    swa_model = get_model(config).cuda()
    ckpt_paths = get_checkpoints(config, num_checkpoint, epoch_end)

    utils.checkpoint.load_checkpoint(config, swa_model, ckpt_paths[0])
    for position, ckpt_path in enumerate(ckpt_paths[1:], start=2):
        candidate = get_model(config).cuda()
        _epoch, _, _ = utils.checkpoint.load_checkpoint(config, candidate,
                                                        ckpt_path)
        swa.moving_average(swa_model, candidate, 1. / position)

    with torch.no_grad():
        swa.bn_update(val_loader, swa_model)

    utils.checkpoint.save_checkpoint(
        config,
        swa_model,
        None,
        None,
        epoch_end,
        name=output_filename,
        weights_dict={'state_dict': swa_model.state_dict()},
    )
Beispiel #3
0
def run(args):
    """Average the checkpoints of one training fold and save the result.

    Rows of ``args.df_path`` whose ``fold`` column differs from
    ``args.fold`` form the training split, which is used to refresh
    batch-norm statistics. When ``args.ema`` is None, checkpoint i (0-based,
    counting from the second) is blended in with weight ``1 / (i + 2)``;
    otherwise every checkpoint is blended in with the constant weight
    ``args.ema``. The model is saved as ``model_swa_<n>`` or
    ``model_ema_<n>``, where ``n`` is the number of checkpoints.
    """
    frame = pd.read_csv(args.df_path)
    train_rows = frame[frame['fold'] != args.fold]

    averaged = get_model(args).cuda()
    loader = get_dataloader(args.data_dir, train_rows, 'train',
                            args.pretrain, args.batch_size)
    ckpt_list = get_checkpoints(args)

    # load_checkpoint(args, model, ckpt_name, checkpoint=None, optimizer=None)
    checkpoint.load_checkpoint(args, averaged, None,
                               checkpoint=ckpt_list[0])
    for idx, ckpt in enumerate(ckpt_list[1:]):
        print(idx, ckpt)
        incoming = get_model(args).cuda()
        _epoch, _ = checkpoint.load_checkpoint(args, incoming, None,
                                               checkpoint=ckpt)
        blend = 1. / (idx + 2) if args.ema is None else args.ema
        swa.moving_average(averaged, incoming, blend)

    with torch.no_grad():
        swa.bn_update(loader, averaged)

    output_name = (f'model_swa_{len(ckpt_list)}' if args.ema is None
                   else f'model_ema_{len(ckpt_list)}')
    print('save {}'.format(output_name))

    state = {'state_dict': averaged.state_dict()}
    checkpoint.save_checkpoint(args, averaged, None, 0, 0,
                               name=output_name, weights_dict=state)
Beispiel #4
0
def run(config, num_checkpoint, epoch_end, output_filename):
    """Average several checkpoints into one model (SWA) and save it.

    Models are moved to the GPU only when ``torch.cuda.is_available()``.
    The first checkpoint seeds ``model``; each further checkpoint is blended
    in with ``swa.moving_average(model, model2, 1 / (i + 2))``. Batch-norm
    statistics are then refreshed with ``swa.bn_update`` on the 'train'
    split, loaded with the 'val' transform.

    Args:
        config: experiment configuration forwarded to the project helpers.
        num_checkpoint: forwarded to ``get_checkpoints``; also embedded in
            the output name.
        epoch_end: forwarded to ``get_checkpoints``.
        output_filename: prefix of the saved name
            ``'{output_filename}.{num_checkpoint}.{last_epoch:03d}'``, where
            ``last_epoch`` comes from the last checkpoint loaded.
    """
    def build_model():
        # Fresh model instance, moved to the GPU when CUDA is available.
        net = get_model(config)
        if torch.cuda.is_available():
            net = net.cuda()
        return net

    dataloader = get_dataloader(config, 'train', get_transform(config, 'val'))

    model = build_model()
    checkpoints = get_checkpoints(config, num_checkpoint, epoch_end)

    # Unpack the first load too, so ``last_epoch`` is bound even when only
    # one checkpoint is averaged (the loop below is then empty).
    last_epoch, _ = utils.checkpoint.load_checkpoint(model, None,
                                                     checkpoints[0])
    for i, checkpoint in enumerate(checkpoints[1:]):
        model2 = build_model()
        last_epoch, _ = utils.checkpoint.load_checkpoint(model2, None,
                                                         checkpoint)
        swa.moving_average(model, model2, 1. / (i + 2))

    with torch.no_grad():
        swa.bn_update(dataloader, model)

    output_name = '{}.{}.{:03d}'.format(output_filename, num_checkpoint,
                                        last_epoch)
    print('save {}'.format(output_name))
    utils.checkpoint.save_checkpoint(config, model, None, 0, 0,
                                     name=output_name,
                                     weights_dict={'state_dict': model.state_dict()})
Beispiel #5
0
def inference_samples(args,
                      model,
                      transform,
                      batch_size,
                      query_txt,
                      query_dir,
                      gallery_dir,
                      save_dir,
                      k1=20,
                      k2=6,
                      p=0.3,
                      use_rerank=False,
                      use_flip=False,
                      max_rank=200,
                      bn_keys=None):
    """Embed unlabeled query/gallery images, rank the gallery, save results.

    Pipeline:
      1. Build the query and gallery file lists (queries first).
      2. Optionally run AdaBN (``args.adabn``).
      3. Extract features, optionally concatenated with those of the input
         reversed along dim 3 (``use_flip``).
      4. Optionally apply DBA (``args.dba``) and/or weighted query expansion
         (``args.aqe``).
      5. Optionally copy pseudo-labelled images to
         ``args.pseudo_savepath`` (``args.pseudo``).
      6. Compute the query x gallery distance matrix with ``re_rank`` or
         ``mgn_euclidean_dist``.
      7. Write a JSON ranking and a pickled distance matrix to ``save_dir``.

    Args:
        args: option namespace (``save_fname``, ``adabn*``, ``dba*``,
            ``aqe*``, ``pseudo*``, ``post*`` fields are read).
        model: network producing one feature row per image.
        transform: image transform given to ``ImageDataset``.
        batch_size: batch size of both dataloaders.
        query_txt: optional list file. When non-empty, the second
            ``/``-separated component of each line's first whitespace token
            names a query image inside ``query_dir``. When empty, every entry
            of ``query_dir`` is a query.
        query_dir: directory of query images.
        gallery_dir: directory of gallery images.
        save_dir: prefix to which output file names are appended verbatim,
            so it should end with a path separator.
        k1, k2, p: re-ranking parameters, used only when ``use_rerank``.
        use_rerank: rank with ``re_rank`` instead of ``mgn_euclidean_dist``.
        use_flip: also embed the flipped input; the model output is assumed
            to be 2048 columns wide.
        max_rank: gallery entries kept per query when ``args.post`` is off.
        bn_keys: BN layer names for AdaBN; ``None``/empty adapts all layers.
    """
    # None instead of a shared mutable [] default.
    if bn_keys is None:
        bn_keys = []

    print("==>load data info..")
    if query_txt != "":
        query_list = []
        with open(query_txt, 'r') as f:
            for line in f:
                # Whitespace split also drops the trailing newline, which
                # split(" ") left glued to single-column entries.
                tokens = line.split()
                if not tokens:
                    continue  # ignore blank lines
                image_name = tokens[0].split("/")[1]
                query_list.append(os.path.join(query_dir, image_name))
    else:
        query_list = [
            os.path.join(query_dir, x) for x in os.listdir(query_dir)
        ]
    gallery_list = [
        os.path.join(gallery_dir, x) for x in os.listdir(gallery_dir)
    ]
    query_num = len(query_list)
    if args.save_fname != '':
        print(query_list[:10])
        query_list = sorted(query_list)
        print(query_list[:10])
        gallery_list = sorted(gallery_list)

    print("==>build dataloader..")
    # Queries come first, so feature rows [:query_num] are the queries.
    image_set = ImageDataset(query_list + gallery_list, transform)
    dataloader = DataLoader(image_set,
                            sampler=SequentialSampler(image_set),
                            batch_size=batch_size,
                            num_workers=4)
    # Shuffled loader, used only for the AdaBN updates below.
    bn_dataloader = DataLoader(image_set,
                               sampler=RandomSampler(image_set),
                               batch_size=batch_size,
                               num_workers=4,
                               drop_last=True)

    print("==>model inference..")

    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           bn_dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, bn_dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats = []
    with torch.no_grad():
        for data in tqdm(dataloader, total=len(dataloader)):
            data = data.cuda()
            if use_flip:
                # [original | flipped] features, 2048 columns each.
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                ff[:, :2048] = model(data).data.cpu()
                # Reverse dim 3 (the width axis if batches are NCHW).
                rev = torch.arange(data.size(3) - 1, -1, -1).long().to('cuda')
                ff[:, 2048:] = model(data.index_select(3, rev)).data.cpu()
            else:
                ff = model(data).data.cpu()
            feats.append(ff)

    all_feature = torch.cat(feats, dim=0)

    # DBA: replace each feature by a logspace-weighted mean of its dba_k2
    # nearest rows, then L2-normalise.
    if args.dba:
        # Separate names so the re-ranking ``k2`` parameter is not clobbered.
        dba_k2 = args.dba_k2
        dba_alpha = args.dba_alpha
        assert dba_alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(dba_k2, dba_alpha))
        st = time.time()

        # TODO: a top-k search would avoid materialising the N x N matrix.
        distmat = euclidean_dist(all_feature, all_feature)
        # Partial sort: the first dba_k2 + 1 columns end up in order.
        initial_rank = np.argpartition(distmat.numpy(),
                                       range(1, dba_k2 + 1))
        del distmat  # release the N x N matrix early

        feats_np = all_feature.numpy()
        V_qe = np.zeros_like(feats_np, dtype=np.float32)
        # Weights run from 10**0 down to 10**dba_alpha (dba_alpha < 0).
        dba_weights = np.logspace(0, dba_alpha, dba_k2).reshape((-1, 1))
        for i in tqdm(range(len(feats_np))):
            V_qe[i, :] = np.mean(feats_np[initial_rank[i, :dba_k2], :] *
                                 dba_weights,
                                 axis=0)
        all_feature = torch.from_numpy(V_qe)

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)

    # aQE: weighted query expansion (GPU implementation).
    if args.aqe:
        aqe_k2 = args.aqe_k2
        aqe_alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            aqe_k2, aqe_alpha))
        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)
        all_feature = aqe_func_gpu(all_feature, aqe_k2, aqe_alpha,
                                   len_slice=2000)
        print("aQE cost:", time.time() - st)
    print('feature shape:', all_feature.size())

    if args.pseudo:
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))

        st = time.time()

        # Note: the normalised features are also used for ranking below.
        all_feature = F.normalize(all_feature, p=2, dim=1)

        if args.pseudo_visual:
            all_distmat, kdist = get_sparse_distmat(
                all_feature,
                eps=args.pseudo_eps + 0.1,
                len_slice=2000,
                use_gpu=True,
                dist_k=args.pseudo_minpoints)
            plt.plot(list(range(len(kdist))), np.sort(kdist), linewidth=0.5)
            # Saved both to the working directory and under save_dir.
            plt.savefig('test_kdist.png')
            plt.savefig(save_dir + 'test_kdist.png')
        else:
            all_distmat = get_sparse_distmat(all_feature,
                                             eps=args.pseudo_eps + 0.1,
                                             len_slice=2000,
                                             use_gpu=True)

        pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                            args.pseudo_minpoints,
                                            args.pseudo_maxpoints,
                                            args.pseudo_algorithm)
        print("pseudo cost: {}s".format(time.time() - st))
        print("pseudo id cnt:", len(pseudolabels))
        print("pseudo img cnt:",
              sum(len(members) for members in pseudolabels.values()))

        # Copy each cluster to <pseudo_savepath>/<pid>/<pid>_c<camid>.png.
        # pid starts at args.pseudo_startid; camid is a running counter
        # shared across all clusters.
        all_list = query_list + gallery_list
        save_path = args.pseudo_savepath
        pid = args.pseudo_startid
        camid = 0
        for members in pseudolabels.values():
            pid_dir = os.path.join(save_path, str(pid))
            os.makedirs(pid_dir, exist_ok=True)
            for index in members:
                new_filename = str(pid) + "_c" + str(camid) + ".png"
                shutil.copy(all_list[index],
                            os.path.join(pid_dir, new_filename))
                camid += 1
            pid += 1

    query_feat = all_feature[:query_num]
    gallery_feat = all_feature[query_num:]

    if use_rerank:
        print("==>use re_rank")
        st = time.time()
        # k2 here is the caller's re-ranking parameter (not dba/aqe k2).
        distmat = re_rank(query_feat, gallery_feat, k1, k2, p)
        print("re_rank cost:", time.time() - st)
    else:
        st = time.time()
        weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]
        print('==> using mgn_euclidean_dist')
        print('==> using weights:', weights)
        if use_flip:
            # Average the distances of the original and flipped halves.
            distmat = None
            for half in range(2):
                cols = slice(half * 2048, (half + 1) * 2048)
                dist = mgn_euclidean_dist(query_feat[:, cols],
                                          gallery_feat[:, cols],
                                          norm=True,
                                          weights=weights)
                if distmat is None:
                    distmat = dist / 2.0
                else:
                    distmat += dist / 2.0
        else:
            distmat = mgn_euclidean_dist(query_feat,
                                         gallery_feat,
                                         norm=True,
                                         weights=weights)

        print("euclidean_dist cost:", time.time() - st)

        distmat = distmat.numpy()

    num_q, _ = distmat.shape
    print("==>saving..")
    if args.post:
        qfnames = [fname.split('/')[-1] for fname in query_list]
        gfnames = [fname.split('/')[-1] for fname in gallery_list]
        st = time.time()
        print("post json using top_per:", args.post_top_per)
        res_dict = get_post_json(distmat, qfnames, gfnames, args.post_top_per)
        print("post cost:", time.time() - st)
    else:
        # TODO: a partial sort would suffice since only max_rank columns
        # are kept.
        print("==>sorting..")
        st = time.time()
        indices = np.argsort(distmat, axis=1)
        print("argsort cost:", time.time() - st)

        top_indices = indices[:, :max_rank]
        res_dict = {}
        for q_idx in range(num_q):
            filename = query_list[q_idx].split('/')[-1]
            res_dict[filename] = [
                gallery_list[g_idx].split('/')[-1]
                for g_idx in top_indices[q_idx]
            ]

    if args.dba:
        save_fname = 'mgnsub_dba.json'
    elif args.aqe:
        save_fname = 'mgnsub_aqe.json'
    else:
        save_fname = 'mgnsub.json'
    if use_rerank:
        save_fname = 'rerank_' + save_fname
    if args.adabn:
        if args.adabn_all:
            save_fname = 'adabnall_' + save_fname
        else:
            save_fname = 'adabn_' + save_fname
    if use_flip:
        save_fname = 'flip_' + save_fname
    if args.post:
        save_fname = 'post_' + save_fname
    save_fname = args.save_fname + save_fname
    print('savefname:', save_fname)
    with open(save_dir + save_fname, 'w', encoding='utf-8') as f:
        json.dump(res_dict, f)
    with open(save_dir + save_fname.replace('.json', '.pkl'), 'wb') as fid:
        pickle.dump(distmat, fid, -1)
Beispiel #6
0
def inference_val(args,
                  model,
                  dataloader,
                  num_query,
                  save_dir,
                  k1=20,
                  k2=6,
                  p=0.3,
                  use_rerank=False,
                  use_flip=False,
                  n_randperm=0,
                  bn_keys=None):
    """Evaluate features on a labelled validation set and log mAP / CMC.

    The dataloader must yield ``(images, pids, camids, _)`` batches, with the
    first ``num_query`` samples being the queries. Features are optionally
    post-processed with DBA (``args.dba``) and weighted query expansion
    (``args.aqe``). Ranking uses ``re_rank`` (with ``args.k1``, ``args.k2``
    and ``p``) or a weighted ``mgn_euclidean_dist``. With ``n_randperm > 0``
    the query/gallery split is instead drawn ``n_randperm`` times from a
    random permutation (seed 0), and metrics are averaged. Results are
    printed and appended to ``save_dir + 'eval.txt'``.

    Args:
        args: option namespace (``adabn*``, ``dba*``, ``aqe*``, ``k1``, ``k2``).
        model: network producing one feature row per image.
        dataloader: validation loader, also used for AdaBN updates.
        num_query: number of leading samples treated as queries.
        save_dir: prefix to which 'eval.txt' is appended verbatim.
        k1, k2: unused; re-ranking reads ``args.k1`` / ``args.k2``. Kept for
            interface compatibility.
        p: re-ranking parameter passed to ``re_rank``.
        use_rerank: rank with ``re_rank`` instead of ``mgn_euclidean_dist``.
        use_flip: also embed the input reversed along dim 3; model output is
            assumed to be 2048 columns wide.
        n_randperm: number of random query/gallery splits; <= 0 disables.
        bn_keys: BN layer names for AdaBN; ``None``/empty adapts all layers.
    """
    # None instead of a shared mutable [] default.
    if bn_keys is None:
        bn_keys = []

    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats, pids, camids = [], [], []
    with torch.no_grad():
        for data, pid, camid, _ in tqdm(dataloader, total=len(dataloader)):
            data = data.cuda()
            if use_flip:
                # [original | flipped] features, 2048 columns each.
                ff = torch.FloatTensor(data.size(0), 2048 * 2).zero_()
                ff[:, :2048] = model(data).data.cpu()
                # Reverse dim 3 (the width axis if batches are NCHW).
                rev = torch.arange(data.size(3) - 1, -1, -1).long().to('cuda')
                ff[:, 2048:] = model(data.index_select(3, rev)).data.cpu()
            else:
                ff = model(data).data.cpu()
            feats.append(ff)
            pids.append(pid)
            camids.append(camid)
    all_feature = torch.cat(feats, dim=0)
    pids = torch.cat(pids, dim=0)
    camids = torch.cat(camids, dim=0)

    # DBA: replace each feature by a logspace-weighted mean of its dba_k2
    # nearest rows, then L2-normalise.
    if args.dba:
        dba_k2 = args.dba_k2
        dba_alpha = args.dba_alpha
        assert dba_alpha < 0
        print("==>using DBA k2:{} alpha:{}".format(dba_k2, dba_alpha))
        st = time.time()

        # TODO: a top-k search would avoid materialising the N x N matrix.
        distmat = euclidean_dist(all_feature, all_feature)
        # Partial sort: the first dba_k2 + 1 columns end up in order.
        initial_rank = np.argpartition(distmat.numpy(),
                                       range(1, dba_k2 + 1))
        del distmat  # release the N x N matrix early

        feats_np = all_feature.numpy()
        V_qe = np.zeros_like(feats_np, dtype=np.float32)
        # Weights run from 10**0 down to 10**dba_alpha (dba_alpha < 0).
        dba_weights = np.logspace(0, dba_alpha, dba_k2).reshape((-1, 1))
        for i in tqdm(range(len(feats_np))):
            V_qe[i, :] = np.mean(feats_np[initial_rank[i, :dba_k2], :] *
                                 dba_weights,
                                 axis=0)
        all_feature = torch.from_numpy(V_qe)

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)

    # aQE: weighted query expansion (GPU implementation).
    if args.aqe:
        aqe_k2 = args.aqe_k2
        aqe_alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            aqe_k2, aqe_alpha))
        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)
        all_feature = aqe_func_gpu(all_feature, aqe_k2, aqe_alpha,
                                   len_slice=2000)
        print("aQE cost:", time.time() - st)

    print('feature shape:', all_feature.size())

    mgn_weights = [1, 1, 1, 1 / 2, 1 / 2, 1 / 3, 1 / 3, 1 / 3]

    def _mgn_distmat(query_feat, gallery_feat):
        # Weighted MGN distance. With flip features, the distances of the
        # original and flipped 2048-column halves are averaged.
        if not use_flip:
            return mgn_euclidean_dist(query_feat,
                                      gallery_feat,
                                      norm=True,
                                      weights=mgn_weights)
        distmat = None
        for half in range(2):
            cols = slice(half * 2048, (half + 1) * 2048)
            dist = mgn_euclidean_dist(query_feat[:, cols],
                                      gallery_feat[:, cols],
                                      norm=True,
                                      weights=mgn_weights)
            if distmat is None:
                distmat = dist / 2.0
            else:
                distmat += dist / 2.0
        return distmat

    if n_randperm <= 0:
        query_feat = all_feature[:num_query]
        gallery_feat = all_feature[num_query:]

        if use_rerank:
            print('==> using rerank')
            distmat = re_rank(query_feat, gallery_feat, args.k1, args.k2, p)
        else:
            print('==> using mgn_euclidean_dist')
            print('==> using weights:', mgn_weights)
            distmat = _mgn_distmat(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(distmat, pids[:num_query].numpy(),
                                pids[num_query:].numpy(),
                                camids[:num_query].numpy(),
                                camids[num_query:].numpy())
    else:
        torch.manual_seed(0)
        cmc = 0
        mAP = 0
        for _ in range(n_randperm):
            index = torch.randperm(all_feature.size()[0])
            feats_perm = all_feature[index]
            pids_perm = pids[index]
            camids_perm = camids[index]

            query_feat = feats_perm[:num_query]
            gallery_feat = feats_perm[num_query:]

            if use_rerank:
                print('==> using rerank')
                st = time.time()
                distmat = re_rank(query_feat, gallery_feat, args.k1, args.k2,
                                  p)
                print("re_rank cost:", time.time() - st)
            else:
                print('==> using euclidean_dist')
                st = time.time()
                print('==> using mgn_euclidean_dist')
                print('==> using weights:', mgn_weights)
                distmat = _mgn_distmat(query_feat, gallery_feat)
                print("euclidean_dist cost:", time.time() - st)

            _cmc, _mAP, _ = eval_func(distmat, pids_perm[:num_query].numpy(),
                                      pids_perm[num_query:].numpy(),
                                      camids_perm[:num_query].numpy(),
                                      camids_perm[num_query:].numpy())
            cmc += _cmc / n_randperm
            mAP += _mAP / n_randperm

    print('Validation Result:')
    if use_rerank:
        # Report the parameters re_rank actually received.
        print(str(args.k1) + "  -  " + str(args.k2) + "  -  " + str(p))
    print('mAP: {:.2%}'.format(mAP))
    for r in [1, 5, 10]:
        print('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    print('average of mAP and rank1: {:.2%}'.format((mAP + cmc[0]) / 2.0))

    with open(save_dir + 'eval.txt', 'a') as f:
        if use_rerank:
            f.write('==> using rerank\n')
            f.write(str(args.k1) + "  -  " + str(args.k2) + "  -  " +
                    str(p) + "\n")
        else:
            f.write('==> using euclidean_dist\n')

        f.write('mAP: {:.2%}'.format(mAP) + "\n")
        for r in [1, 5, 10]:
            f.write('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]) + "\n")
        f.write('average of mAP and rank1: {:.2%}\n'.format(
            (mAP + cmc[0]) / 2.0))

        f.write('------------------------------------------\n')
        f.write('------------------------------------------\n')
        f.write('\n\n')
Beispiel #7
0
def inference_samples(args,
                      model,
                      transform,
                      batch_size,
                      query_txt,
                      query_dir,
                      gallery_dir,
                      save_dir,
                      k1=20,
                      k2=6,
                      p=0.3,
                      use_rerank=False,
                      use_flip=False,
                      max_rank=200,
                      bn_keys=None):
    """Run inference on unlabeled query/gallery images and export results.

    Steps: build the query/gallery file lists, optionally run AdaBN,
    extract (optionally flip-augmented) L2-normalized features, optionally
    apply DBA and/or aQE, then either

    * ``args.pseudo``: cluster the features and copy each cluster's images
      into ``args.pseudo_savepath/<pid>/`` as pseudo-labelled data, or
    * otherwise: rank the gallery for every query and write
      ``save_dir + <prefix>sub*.json`` plus the pickled distance matrix
      (same name, ``.pkl`` extension).

    Args:
        args: option namespace; this function reads ``save_fname``,
            ``adabn``/``adabn_emv``/``adabn_all``, ``dba``/``dba_k2``/
            ``dba_alpha``, ``aqe``/``aqe_k2``/``aqe_alpha``, the ``pseudo*``
            options, ``k1``/``k2`` and ``post``/``post_top_per``.
        model: feature extractor; moved to the global ``device``.
        transform: image transform handed to ``ImageDataset``.
        batch_size: batch size of both dataloaders.
        query_txt: optional list file ("" means every file in
            ``query_dir``). For each non-blank line, the second
            "/"-separated component of its first space-separated field is
            used as the image file name.
        query_dir: directory holding the query images.
        gallery_dir: directory holding the gallery images.
        save_dir: output prefix, joined by plain string concatenation, so it
            should end with a path separator.
        k1: unused; re-ranking reads ``args.k1``.
        k2: unused; re-ranking reads ``args.k2``.
        p: lambda passed to ``re_ranking_batch_gpu``.
        use_rerank: rank with re-ranked distances instead of euclidean ones.
        use_flip: concatenate features of each image and its flipped copy.
        max_rank: number of gallery file names kept per query in the json.
        bn_keys: BN layer keys for selective AdaBN. ``None``/empty means
            AdaBN (when enabled) updates every BN layer.
    """
    # ``None`` default avoids sharing one mutable list across calls.
    if bn_keys is None:
        bn_keys = []

    print("==>load data info..")
    if query_txt != "":
        query_list = []
        with open(query_txt, 'r') as f:
            for line in f:
                # Stripping keeps a trailing "\n" out of the file name when a
                # line holds only the path; blank lines are skipped instead
                # of raising IndexError.
                line = line.strip()
                if not line:
                    continue
                image_name = line.split(" ")[0].split("/")[1]
                query_list.append(os.path.join(query_dir, image_name))
    else:
        query_list = [
            os.path.join(query_dir, x) for x in os.listdir(query_dir)
        ]
    gallery_list = [
        os.path.join(gallery_dir, x) for x in os.listdir(gallery_dir)
    ]
    query_num = len(query_list)
    if args.save_fname != '':
        print(query_list[:10])
        query_list = sorted(query_list)
        print(query_list[:10])
        gallery_list = sorted(gallery_list)

    print("==>build dataloader..")
    image_set = ImageDataset(query_list + gallery_list, transform)
    dataloader = DataLoader(image_set,
                            sampler=SequentialSampler(image_set),
                            batch_size=batch_size,
                            num_workers=6)
    # Shuffled loader, only used to re-estimate BN statistics (AdaBN).
    bn_dataloader = DataLoader(image_set,
                               sampler=RandomSampler(image_set),
                               batch_size=batch_size,
                               num_workers=6,
                               drop_last=True)

    print("==>model inference..")

    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           bn_dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, bn_dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats = []
    with torch.no_grad():
        for data in tqdm(dataloader, total=len(dataloader)):
            data = data.to(device)
            if use_flip:
                # Normalize the features of the batch and of its copy flipped
                # along dim 3 (presumably width for NCHW input), concatenate
                # and normalize again. The feature size comes from the model
                # output rather than being hard-coded to 2048.
                f_orig = F.normalize(model(data).cpu(), p=2, dim=1).float()
                f_flip = F.normalize(model(torch.flip(data, dims=[3])).cpu(),
                                     p=2,
                                     dim=1).float()
                ff = F.normalize(torch.cat([f_orig, f_flip], dim=1),
                                 p=2,
                                 dim=1)
            else:
                ff = F.normalize(model(data).cpu(), p=2, dim=1)
            feats.append(ff)

    all_feature = torch.cat(feats, dim=0)

    # DBA: replace each feature by a weighted mean of its k2 nearest
    # features (normally including itself), then renormalize.
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        if alpha >= 0:
            raise ValueError(
                "dba_alpha must be negative, got {}".format(alpha))
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()

        distmat = euclidean_dist(all_feature, all_feature)
        # Only the k2 smallest distances per row are needed, in sorted
        # order, so a partial sort suffices.
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()

        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        # Weights decay geometrically from 1 to 10**alpha.
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i, neighbours in enumerate(initial_rank[:, :k2]):
                V_qe[i, :] = np.mean(all_feature[neighbours, :] * weights,
                                     axis=0)
                pbar.update(1)
        all_feature = torch.from_numpy(V_qe)
        del V_qe

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)
    # aQE: weight query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            k2, alpha))
        st = time.time()

        # fast by gpu
        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)
    print('feature shape:', all_feature.size())
    if args.pseudo:
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))

        st = time.time()

        all_feature = F.normalize(all_feature, p=2, dim=1)
        if args.pseudo_hist:
            print("==> predict histlabel...")

            img_filenames = query_list + gallery_list
            img_idx = list(range(len(img_filenames)))
            imgs = {'filename': img_filenames, 'identity': img_idx}
            df_img = pd.DataFrame(imgs)
            hist_labels = mmcv.track_parallel_progress(img_hist_predictor,
                                                       df_img['filename'], 6)
            print("hist label describe..")
            unique_hist_labels = sorted(set(hist_labels))
            # label -> position in unique_hist_labels (one lookup per image
            # instead of scanning every label).
            label_pos = {
                label: j
                for j, label in enumerate(unique_hist_labels)
            }

            hl_idx = [[] for _ in unique_hist_labels]
            hl_query_infos = [[] for _ in unique_hist_labels]
            hl_gallery_infos = [[] for _ in unique_hist_labels]
            with tqdm(total=len(img_filenames)) as pbar:
                for idx, info in enumerate(img_filenames):
                    label_idx = label_pos[hist_labels[idx]]
                    if idx < len(query_list):
                        hl_query_infos[label_idx].append(info)
                    else:
                        hl_gallery_infos[label_idx].append(info)
                    hl_idx[label_idx].append(idx)
                    pbar.update(1)
            for label_idx, label in enumerate(unique_hist_labels):
                n_query = len(hl_query_infos[label_idx])
                n_gallery = len(hl_gallery_infos[label_idx])
                print('hist_label:', label, ' query number:', n_query)
                print('hist_label:', label, ' gallery number:', n_gallery)
                print('hist_label:', label, ' q+g number:',
                      n_query + n_gallery)
                print('hist_label:', label, ' idx q+g number:',
                      len(hl_idx[label_idx]))
            # pseudo
            pid = args.pseudo_startid
            camid = 0
            all_list = query_list + gallery_list
            save_path = args.pseudo_savepath
            pseudo_minpoints = args.pseudo_minpoints
            for label_idx in range(len(unique_hist_labels)):
                # NOTE(review): hard-coded eps per histogram group; this
                # overrides args.pseudo_eps (only shown in the log line
                # above). Kept to preserve existing behavior.
                pseudo_eps = 0.65 if label_idx == 0 else 0.80

                feature = all_feature[hl_idx[label_idx]]
                img_list = [all_list[idx] for idx in hl_idx[label_idx]]

                print("==> get sparse distmat!")
                if args.pseudo_visual:
                    all_distmat, kdist = get_sparse_distmat(
                        feature,
                        eps=pseudo_eps + 0.05,
                        len_slice=2000,
                        use_gpu=True,
                        dist_k=pseudo_minpoints)
                    plt.plot(list(range(len(kdist))),
                             np.sort(kdist),
                             linewidth=0.5)
                    plt.savefig('test_kdist_hl{}_eps{}_{}.png'.format(
                        label_idx, pseudo_eps, pseudo_minpoints))
                    plt.savefig(save_dir +
                                'test_kdist_hl{}_eps{}_{}.png'.format(
                                    label_idx, pseudo_eps, pseudo_minpoints))
                else:
                    all_distmat = get_sparse_distmat(feature,
                                                     eps=pseudo_eps + 0.05,
                                                     len_slice=2000,
                                                     use_gpu=True)

                print("==> predict pseudo label!")
                pseudolabels = predict_pseudo_label(all_distmat, pseudo_eps,
                                                    pseudo_minpoints,
                                                    args.pseudo_maxpoints,
                                                    args.pseudo_algorithm)
                print(
                    "==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
                        pseudo_eps, pseudo_minpoints, args.pseudo_maxpoints))
                print("pseudo cost: {}s".format(time.time() - st))
                print("pseudo id cnt:", len(pseudolabels))
                print("pseudo img cnt:",
                      len([x for k, v in pseudolabels.items() for x in v]))

                sample_id_cnt = 0
                sample_file_cnt = 0
                nignore_query = 0  # this branch applies no cluster filter
                for members in pseudolabels.values():
                    os.makedirs(os.path.join(save_path, str(pid)),
                                exist_ok=True)
                    for _index in members:
                        new_filename = str(pid) + "_c" + str(camid) + ".png"
                        shutil.copy(
                            img_list[_index],
                            os.path.join(save_path, str(pid), new_filename))
                        camid += 1
                        sample_file_cnt += 1
                    sample_id_cnt += 1
                    pid += 1
                print("pseudo ignore id cnt:", nignore_query)
                print("sample id cnt:", sample_id_cnt)
                print("sample file cnt:", sample_file_cnt)
        else:
            if args.pseudo_visual:
                all_distmat, kdist = get_sparse_distmat(
                    all_feature,
                    eps=args.pseudo_eps + 0.05,
                    len_slice=2000,
                    use_gpu=True,
                    dist_k=args.pseudo_minpoints)
                plt.plot(list(range(len(kdist))),
                         np.sort(kdist),
                         linewidth=0.5)
                plt.savefig('test_kdist.png')
                plt.savefig(save_dir + 'test_kdist.png')
            else:
                all_distmat = get_sparse_distmat(all_feature,
                                                 eps=args.pseudo_eps + 0.05,
                                                 len_slice=2000,
                                                 use_gpu=True)

            pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                                args.pseudo_minpoints,
                                                args.pseudo_maxpoints,
                                                args.pseudo_algorithm)
            print("pseudo cost: {}s".format(time.time() - st))
            print("pseudo id cnt:", len(pseudolabels))
            print("pseudo img cnt:",
                  len([x for k, v in pseudolabels.items() for x in v]))

            # save
            all_list = query_list + gallery_list
            save_path = args.pseudo_savepath
            n_query = len(query_list)

            pid = args.pseudo_startid
            camid = 0
            nignore_query = 0
            for members in pseudolabels.values():
                # Filter: drop clusters that contain 4 or more query images.
                query_cnt = sum(1 for _index in members if _index < n_query)
                if query_cnt >= 4:
                    nignore_query += 1
                    continue
                # Create the folder only for kept clusters, so a skipped
                # cluster never leaves an empty directory behind.
                os.makedirs(os.path.join(save_path, str(pid)), exist_ok=True)
                for _index in members:
                    new_filename = str(pid) + "_c" + str(camid) + ".png"
                    shutil.copy(
                        all_list[_index],
                        os.path.join(save_path, str(pid), new_filename))
                    camid += 1
                pid += 1
            print("pseudo ignore id cnt:", nignore_query)
    else:
        gallery_feat = all_feature[query_num:]
        query_feat = all_feature[:query_num]

        if use_rerank:
            print("==>use re_rank")
            st = time.time()
            num_query = len(query_feat)
            print("using k1:{} k2:{} lambda:{}".format(args.k1, args.k2, p))
            distmat = re_ranking_batch_gpu(
                torch.cat([query_feat, gallery_feat], dim=0), num_query,
                args.k1, args.k2, p)

            print("re_rank cost:", time.time() - st)

        else:
            print("==>use euclidean_dist")
            st = time.time()
            distmat = euclidean_dist(query_feat, gallery_feat)
            print("euclidean_dist cost:", time.time() - st)

            distmat = distmat.numpy()

        num_q = distmat.shape[0]
        print("==>saving..")
        if args.post:
            qfnames = [fname.split('/')[-1] for fname in query_list]
            gfnames = [fname.split('/')[-1] for fname in gallery_list]
            st = time.time()
            print("post json using top_per:", args.post_top_per)
            res_dict = get_post_json(distmat, qfnames, gfnames,
                                     args.post_top_per)
            print("post cost:", time.time() - st)
        else:
            print("==>sorting..")
            st = time.time()
            indices = np.argsort(distmat, axis=1)
            print("argsort cost:", time.time() - st)

            top_indices = indices[:, :max_rank]
            # Basenames computed once instead of once per query.
            gallery_names = [fname.split('/')[-1] for fname in gallery_list]
            res_dict = {}
            for q_idx in range(num_q):
                filename = query_list[q_idx].split('/')[-1]
                res_dict[filename] = [
                    gallery_names[i] for i in top_indices[q_idx]
                ]
        if args.dba:
            save_fname = 'sub_dba.json'
        elif args.aqe:
            save_fname = 'sub_aqe.json'
        else:
            save_fname = 'sub.json'
        if use_rerank:
            save_fname = 'rerank_' + save_fname
        if args.adabn:
            if args.adabn_all:
                save_fname = 'adabnall_' + save_fname
            else:
                save_fname = 'adabn_' + save_fname
        if use_flip:
            save_fname = 'flip_' + save_fname
        if args.post:
            save_fname = 'post_' + save_fname
        save_fname = args.save_fname + save_fname
        print('savefname:', save_fname)
        with open(save_dir + save_fname, 'w', encoding='utf-8') as f:
            json.dump(res_dict, f)
        with open(save_dir + save_fname.replace('.json', '.pkl'), 'wb') as fid:
            pickle.dump(distmat, fid, -1)
Beispiel #8
0
def inference_val(args,
                  model,
                  dataloader,
                  num_query,
                  save_dir,
                  k1=20,
                  k2=6,
                  p=0.3,
                  use_rerank=False,
                  use_flip=False,
                  n_randperm=0,
                  bn_keys=None):
    """Evaluate ``model`` on a labelled split; print and log CMC / mAP.

    ``dataloader`` must yield ``(images, pids, camids, _)`` batches in a
    fixed order. The first ``num_query`` samples are treated as queries and
    the rest as gallery. The same loader is used for AdaBN when enabled.
    Features can optionally be post-processed with DBA and/or aQE. When
    ``args.pseudo`` is set, pseudo-label clustering statistics are printed
    (nothing is written to disk for them).

    Args:
        args: option namespace; this function reads ``adabn``/``adabn_emv``,
            ``dba``/``dba_k2``/``dba_alpha``, ``aqe``/``aqe_k2``/
            ``aqe_alpha``, the ``pseudo*`` options and ``k1``/``k2``.
        model: feature extractor; moved to the global ``device``.
        dataloader: evaluation loader (see above).
        num_query: number of query samples.
        save_dir: prefix of ``eval.txt`` (plain string concatenation, so it
            should end with a path separator); results are appended there.
        k1: ignored; re-ranking uses (and logs) ``args.k1``.
        k2: ignored; ``args.k2`` is used and logged.
        p: lambda passed to ``re_ranking_batch_gpu``.
        use_rerank: use re-ranked distances instead of euclidean ones.
        use_flip: concatenate features of each image and its flipped copy.
        n_randperm: if > 0, ignore the loader's query/gallery split. Instead,
            shuffle all samples ``n_randperm`` times (seeded with
            ``torch.manual_seed(0)``), take the first ``num_query`` as
            queries, and average CMC / mAP over the permutations.
        bn_keys: BN layer keys for selective AdaBN. ``None``/empty means
            AdaBN (when enabled) updates every BN layer.
    """
    # ``None`` default avoids sharing one mutable list across calls.
    if bn_keys is None:
        bn_keys = []

    model = model.to(device)
    if args.adabn and len(bn_keys) > 0:
        print("==> using adabn for specific bn layers")
        specific_bn_update(model,
                           dataloader,
                           cumulative=not args.adabn_emv,
                           bn_keys=bn_keys)
    elif args.adabn:
        print("==> using adabn for all bn layers")
        bn_update(model, dataloader, cumulative=not args.adabn_emv)

    model.eval()
    feats, pids, camids = [], [], []
    with torch.no_grad():
        for data, pid, camid, _ in tqdm(dataloader, total=len(dataloader)):
            data = data.to(device)
            if use_flip:
                # Normalize the features of the batch and of its copy flipped
                # along dim 3 (presumably width for NCHW input), concatenate
                # and normalize again. The feature size comes from the model
                # output rather than being hard-coded to 2048.
                f_orig = F.normalize(model(data).cpu(), p=2, dim=1).float()
                f_flip = F.normalize(model(torch.flip(data, dims=[3])).cpu(),
                                     p=2,
                                     dim=1).float()
                ff = F.normalize(torch.cat([f_orig, f_flip], dim=1),
                                 p=2,
                                 dim=1)
            else:
                ff = F.normalize(model(data).cpu(), p=2, dim=1)

            feats.append(ff)
            pids.append(pid)
            camids.append(camid)
    all_feature = torch.cat(feats, dim=0)
    pids = torch.cat(pids, dim=0)
    camids = torch.cat(camids, dim=0)

    # DBA: replace each feature by a weighted mean of its k2 nearest
    # features (normally including itself), then renormalize.
    if args.dba:
        k2 = args.dba_k2
        alpha = args.dba_alpha
        if alpha >= 0:
            raise ValueError(
                "dba_alpha must be negative, got {}".format(alpha))
        print("==>using DBA k2:{} alpha:{}".format(k2, alpha))
        st = time.time()

        distmat = euclidean_dist(all_feature, all_feature)
        # Only the k2 smallest distances per row are needed, in sorted
        # order, so a partial sort suffices.
        initial_rank = np.argpartition(distmat.numpy(), range(1, k2 + 1))

        all_feature = all_feature.numpy()

        V_qe = np.zeros_like(all_feature, dtype=np.float32)
        # Weights decay geometrically from 1 to 10**alpha.
        weights = np.logspace(0, alpha, k2).reshape((-1, 1))
        with tqdm(total=len(all_feature)) as pbar:
            for i, neighbours in enumerate(initial_rank[:, :k2]):
                V_qe[i, :] = np.mean(all_feature[neighbours, :] * weights,
                                     axis=0)
                pbar.update(1)
        all_feature = torch.from_numpy(V_qe)
        del V_qe

        fnorm = torch.norm(all_feature, p=2, dim=1, keepdim=True)
        all_feature = all_feature.div(fnorm.expand_as(all_feature))
        print("DBA cost:", time.time() - st)
    # aQE: weight query expansion
    if args.aqe:
        k2 = args.aqe_k2
        alpha = args.aqe_alpha

        print("==>using weight query expansion k2: {} alpha: {}".format(
            k2, alpha))
        st = time.time()

        all_feature = aqe_func_gpu(all_feature, k2, alpha, len_slice=2000)
        print("aQE cost:", time.time() - st)

    if args.pseudo:
        print("==> using pseudo eps:{} minPoints:{} maxpoints:{}".format(
            args.pseudo_eps, args.pseudo_minpoints, args.pseudo_maxpoints))
        st = time.time()
        # cal sparse distmat
        all_feature = F.normalize(all_feature, p=2, dim=1)

        if args.pseudo_visual:
            all_distmat, kdist = get_sparse_distmat(
                all_feature,
                eps=args.pseudo_eps + 0.1,
                len_slice=2000,
                use_gpu=True,
                dist_k=args.pseudo_minpoints)
            plt.plot(list(range(len(kdist))), np.sort(kdist), linewidth=0.5)
            plt.savefig('eval_kdist.png')
            plt.savefig(save_dir + 'eval_kdist.png')
        else:
            all_distmat = get_sparse_distmat(all_feature,
                                             eps=args.pseudo_eps + 0.1,
                                             len_slice=2000,
                                             use_gpu=True)

        pseudolabels = predict_pseudo_label(all_distmat, args.pseudo_eps,
                                            args.pseudo_minpoints,
                                            args.pseudo_maxpoints,
                                            args.pseudo_algorithm)
        print("pseudo cost: {}s".format(time.time() - st))
        print("pseudo id cnt:", len(pseudolabels))
        print("pseudo img cnt:",
              len([x for k, v in pseudolabels.items() for x in v]))
    print('feature shape:', all_feature.size())

    # Re-ranking reads its hyper-parameters from ``args``. Mirror them locally
    # so the values printed/logged below are the ones actually used.
    k2 = args.k2
    if use_rerank:
        k1 = args.k1

    if n_randperm <= 0:
        gallery_feat = all_feature[num_query:]
        query_feat = all_feature[:num_query]

        query_pid = pids[:num_query]
        query_camid = camids[:num_query]

        gallery_pid = pids[num_query:]
        gallery_camid = camids[num_query:]

        if use_rerank:
            print('==> using rerank')
            distmat = re_ranking_batch_gpu(
                torch.cat([query_feat, gallery_feat], dim=0), num_query,
                args.k1, args.k2, p)
        else:
            print('==> using euclidean_dist')
            distmat = euclidean_dist(query_feat, gallery_feat)

        cmc, mAP, _ = eval_func(distmat, query_pid.numpy(),
                                gallery_pid.numpy(), query_camid.numpy(),
                                gallery_camid.numpy())
    else:
        torch.manual_seed(0)
        cmc = 0
        mAP = 0
        for _ in range(n_randperm):
            index = torch.randperm(all_feature.size()[0])
            # Permute once, then split into query / gallery.
            feat_perm = all_feature[index]
            pid_perm = pids[index]
            camid_perm = camids[index]

            query_feat = feat_perm[:num_query]
            gallery_feat = feat_perm[num_query:]

            query_pid = pid_perm[:num_query]
            query_camid = camid_perm[:num_query]

            gallery_pid = pid_perm[num_query:]
            gallery_camid = camid_perm[num_query:]

            if use_rerank:
                print('==> using rerank')
                st = time.time()
                distmat = re_ranking_batch_gpu(
                    torch.cat([query_feat, gallery_feat], dim=0), num_query,
                    args.k1, args.k2, p)

                print("re_rank cost:", time.time() - st)

            else:
                print('==> using euclidean_dist')
                st = time.time()
                distmat = euclidean_dist(query_feat, gallery_feat)
                print("euclidean_dist cost:", time.time() - st)

            _cmc, _mAP, _ = eval_func(distmat, query_pid.numpy(),
                                      gallery_pid.numpy(), query_camid.numpy(),
                                      gallery_camid.numpy())
            cmc += _cmc / n_randperm
            mAP += _mAP / n_randperm

    print('Validation Result:')
    if use_rerank:
        print(str(k1) + "  -  " + str(k2) + "  -  " + str(p))
    print('mAP: {:.2%}'.format(mAP))
    for r in [1, 5, 10]:
        print('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]))
    print('average of mAP and rank1: {:.2%}'.format((mAP + cmc[0]) / 2.0))

    with open(save_dir + 'eval.txt', 'a') as f:
        if use_rerank:
            f.write('==> using rerank\n')
            f.write(str(k1) + "  -  " + str(k2) + "  -  " + str(p) + "\n")
        else:
            f.write('==> using euclidean_dist\n')

        f.write('mAP: {:.2%}'.format(mAP) + "\n")
        for r in [1, 5, 10]:
            f.write('CMC Rank-{}: {:.2%}'.format(r, cmc[r - 1]) + "\n")
        f.write('average of mAP and rank1: {:.2%}\n'.format(
            (mAP + cmc[0]) / 2.0))

        f.write('------------------------------------------\n')
        f.write('------------------------------------------\n')
        f.write('\n\n')