def forward(self, source_features, target_features):
        """Compute MMD losses between source- and target-domain distance distributions.

        Both feature batches are reshaped into ``self.batch_size // self.instances``
        consecutive groups of ``self.instances`` rows, one group per identity
        (assumes the sampler lays the batch out this way -- TODO confirm).
        For every group a within-class (wc) distance matrix is computed, and
        for every ordered pair of distinct groups a between-class (bc)
        distance matrix is computed.

        Args:
            source_features: 2-D source-domain feature tensor,
                ``(batch_size, feature_size)``.
            target_features: 2-D target-domain feature tensor with the same
                layout as ``source_features``.

        Returns:
            ``(mmd_wc, mmd_bc, mmd_global)``: ``self.mmd_loss`` applied to the
            within-class distances, the between-class distances and the raw
            features respectively.

        Raises:
            ValueError: if the batch holds fewer than two identity groups, in
                which case no between-class distances exist.
        """
        instances = self.instances
        num_groups = self.batch_size // instances
        feature_size = target_features.shape[1]
        if num_groups < 2:
            raise ValueError(
                'need at least two identity groups per batch, got {}'.format(num_groups))

        def _class_distances(features):
            # Collect all matrices first and concatenate once, instead of
            # re-allocating a growing tensor with torch.cat on every step.
            groups = torch.reshape(features, (num_groups, instances, feature_size))
            within, between = [], []
            for i in range(num_groups):
                within.append(compute_distance_matrix(groups[i], groups[i]))
                for j in range(num_groups):
                    # Compare group indices, not tensor contents: every group
                    # (including group 0) is paired with every other group.
                    if j != i:
                        between.append(compute_distance_matrix(groups[i], groups[j]))
            return torch.cat(within), torch.cat(between)

        wct, bct = _class_distances(target_features)
        wcs, bcs = _class_distances(source_features)

        # Only the target distribution should be adapted, so the source-side
        # distance statistics are treated as constants.
        wcs = wcs.detach()
        bcs = bcs.detach()

        return self.mmd_loss(wcs, wct), self.mmd_loss(bcs, bct), self.mmd_loss(source_features, target_features)
Beispiel #2
0
    def run(self, compaer_img, origin_f, origin_name, dist_metric='cosine',
            threshold=0.30):
        """Identify the person in ``compaer_img`` against known features.

        Args:
            compaer_img: single BGR image to compare with the base images.
            origin_f: feature matrix of the known identities; row ``k`` is
                compared as the entry for ``origin_name[k]``.
            origin_name: identity names aligned with the rows of ``origin_f``.
            dist_metric: metric forwarded to ``metrics.compute_distance_matrix``.
            threshold: the best match is accepted only if its distance is
                strictly below this value (default 0.30).

        Returns:
            ``(identify_name, min_distance)``; ``identify_name`` is
            ``"Unknown"`` when the smallest distance is not below ``threshold``.
        """
        compare_f = self._extract_feature(self.model, compaer_img).data.cpu()

        if self.is_normalize_f:
            compare_f = F.normalize(compare_f, p=2, dim=1)

        distmat = metrics.compute_distance_matrix(compare_f,
                                                  origin_f,
                                                  metric=dist_metric)
        # Only the first row is used: compaer_img is a single image.
        dist_list = distmat.numpy().tolist()[0]

        min_dist = min(dist_list)
        top_id = dist_list.index(min_dist)
        identify_name = origin_name[top_id] if min_dist < threshold else "Unknown"

        return identify_name, min_dist
Beispiel #3
0
    def get_features(self, query_imgs, gallery_imgs):
        """Extract features for query and gallery images and compare them.

        Args:
            query_imgs: iterable of image arrays; each is cast to uint8 and
                converted to an RGB PIL image before ``self.transform_te``.
            gallery_imgs: iterable of gallery image arrays in the same format.

        Returns:
            numpy query-by-gallery distance matrix computed by
            ``metrics.compute_distance_matrix`` with ``self.dist_metric``.
        """
        def _extract(imgs):
            # One forward pass per image; CPU features are stacked row-wise.
            feats = []
            for img in imgs:
                img = Image.fromarray(img.astype('uint8')).convert('RGB')
                img = torch.unsqueeze(self.transform_te(img), 0).cuda()
                feats.append(self._extract_features(img).data.cpu())
            return torch.cat(feats, 0)

        # Features are detached (.data) immediately, so no autograd graph is
        # needed; disabling it avoids keeping activations around per image.
        with torch.no_grad():
            qf = _extract(query_imgs)
            gf = _extract(gallery_imgs)

        distmat = metrics.compute_distance_matrix(qf, gf, self.dist_metric)
        return distmat.numpy()
 def get_local_correl(self, local_feat, num_parts=6, feat_dim=2048):
     """Return the distance matrix between all local (part) features of a batch.

     ``local_feat`` is reshaped to ``(local_feat.size(0) * num_parts, feat_dim)``,
     so each sample is expected to carry ``num_parts`` part features of
     ``feat_dim`` channels (defaults 6 and 2048, the previously hard-coded values).

     Args:
         local_feat: tensor whose first dimension is the batch size and whose
             remaining elements total ``num_parts * feat_dim`` per sample.
         num_parts: number of part features per sample.
         feat_dim: dimensionality of each part feature.

     Returns:
         result of ``compute_distance_matrix`` over all part features of the batch.
     """
     local_feat = local_feat.reshape(local_feat.size(0) * num_parts, feat_dim)
     return compute_distance_matrix(local_feat, local_feat)
Beispiel #5
0
    def get_local_correl(self, local_feat):
        """Flatten each sample's part-to-part distance matrix into one row.

        For sample ``i`` the distance matrix of ``local_feat[i]`` with itself
        is computed and flattened; the rows are stored in a float tensor of
        shape ``(local_feat.size(0), local_feat.size(1) ** 2)`` that is moved
        to the GPU before returning.
        """
        n_samples = local_feat.size(0)
        n_parts = local_feat.size(1)
        correl = torch.zeros((n_samples, n_parts * n_parts))
        for row, sample in enumerate(local_feat):
            correl[row] = compute_distance_matrix(sample, sample).reshape(-1)
        return correl.cuda()
Beispiel #6
0
def test(model, queryloader, galleryloader, dist_metric, normalize_feature,
         topk=200, output_path="result.json"):
    """Rank gallery images for every query and dump the top-k names to JSON.

    Args:
        model: network mapping a batch of images to a feature matrix.
        queryloader: iterable of ``(imgs, img_names)`` query batches.
        galleryloader: iterable of ``(imgs, img_names)`` gallery batches.
        dist_metric: metric name forwarded to ``compute_distance_matrix``.
        normalize_feature: L2-normalise features before computing distances.
        topk: number of gallery names kept per query; fewer are kept when the
            gallery is smaller.
        output_path: JSON file mapping each query name to its ranked gallery
            names.
    """
    batch_time = AverageMeter()

    def _extract(loader):
        # Stack CPU features and collect names in loader order.
        feats, names = [], []
        for imgs, img_names in loader:
            imgs = imgs.cuda()
            end = time.time()
            features = model(imgs)
            batch_time.update(time.time() - end)
            feats.append(features.data.cpu())
            names.extend(img_names)
        return torch.cat(feats, 0), names

    model.eval()
    with torch.no_grad():
        print('Extracting features from query set ...')
        qf, q_names = _extract(queryloader)
        print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        print('Extracting features from gallery set ...')
        gf, g_names = _extract(galleryloader)
        print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

    print('Speed: {:.4f} sec/batch'.format(batch_time.avg))

    if normalize_feature:
        print('Normalzing features with L2 norm ...')
        qf = F.normalize(qf, p=2, dim=1)
        gf = F.normalize(gf, p=2, dim=1)

    print('Computing distance matrix with metric={} ...'.format(dist_metric))
    distmat = compute_distance_matrix(qf, gf, dist_metric).numpy()
    indices = np.argsort(distmat, axis=1)

    # Slicing (instead of range(topk)) stays in bounds when the gallery has
    # fewer than ``topk`` images.
    rank = {
        q_name: [g_names[g_idx] for g_idx in indices[q_idx, :topk]]
        for q_idx, q_name in enumerate(q_names)
    }

    with open(output_path, "w") as f:
        json.dump(rank, f)
    print('done')
Beispiel #7
0
def evaluate(model, queryloader, galleryloader, dist_metric,
             normalize_feature):
    """Extract query/gallery features, then print rank-1, mAP and their average.

    Both loaders yield ``(imgs, pids)`` batches. Features are moved to the
    CPU, optionally L2-normalised, compared with ``compute_distance_matrix``
    and scored by ``eval_rank``. Nothing is returned; results are printed.
    """
    timer = AverageMeter()

    def collect(loader):
        chunks, id_list = [], []
        for images, ids in loader:
            images = images.cuda()
            start = time.time()
            out = model(images)
            timer.update(time.time() - start)
            chunks.append(out.data.cpu())
            id_list.extend(ids)
        return torch.cat(chunks, 0), np.asarray(id_list)

    model.eval()
    with torch.no_grad():
        print('Extracting features from query set ...')
        qf, q_pids = collect(queryloader)
        print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        print('Extracting features from gallery set ...')
        gf, g_pids = collect(galleryloader)
        print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

    print('Speed: {:.4f} sec/batch'.format(timer.avg))

    if normalize_feature:
        print('Normalzing features with L2 norm ...')
        qf, gf = F.normalize(qf, p=2, dim=1), F.normalize(gf, p=2, dim=1)

    print('Computing distance matrix with metric={} ...'.format(dist_metric))
    distmat = compute_distance_matrix(qf, gf, dist_metric).numpy()

    print('Computing rank1 and mAP ...')
    rank1, mAP, result = eval_rank(distmat, q_pids, g_pids)
    print('** Results **')
    print('Rank1: {:.8f}'.format(rank1))
    print('mAP: {:.8f}'.format(mAP))
    print('average: {:.8f}'.format(result))
Beispiel #8
0
def generate_person(person_features,
                    person_boxes,
                    face_features=None,
                    face_boxes=None,
                    face_effective=None,
                    out_face_threshold=opt.out_face_threshold,
                    face_threashold=opt.face_threshold,
                    metric=opt.face_metric):
    """Create ``Person`` objects from the detected person and face data.

    One ``Person`` is built per entry of ``person_boxes``. When
    ``face_effective`` is non-empty, each face is matched against the
    module-level ``database_features`` / ``database_labels`` by linear
    assignment; a face is named only if its distance is below
    ``face_threashold``. Faces are then assigned to persons by a second
    linear assignment on ``1 - person_face_cost_cpp(person_boxes, face_boxes)``.

    Returns:
        list of ``Person`` objects, one per person box.
    """
    people = [Person(person_features[k], box) for k, box in enumerate(person_boxes)]
    if not face_effective:
        return people

    n_faces = len(face_effective)
    names = ['UnKnown'] * n_faces
    # Unmatched faces keep a distance just above the acceptance threshold.
    distances = [face_threashold + 0.1] * n_faces

    face_cost = compute_distance_matrix(face_features,
                                        database_features,
                                        metric=metric)
    for a, b in linear_assignment(face_cost):
        distances[a] = face_cost[a][b].item()
        if face_cost[a][b] < face_threashold:
            names[a] = database_labels[b]

    person_face_cost = 1 - person_face_cost_cpp(person_boxes, face_boxes)
    skipped = filter_matches_between_people_and_face_frames(person_face_cost)
    for a, b in linear_assignment(person_face_cost):
        if a in skipped:
            continue
        if person_face_cost[a][b] <= out_face_threshold and b in face_effective:
            idx = face_effective.index(b)
            person = people[a]
            person.fBox = face_boxes[b]
            person.fid = face_features[idx]
            person.fid_distance = distances[idx]
            person.name = names[idx]
    return people
Beispiel #9
0
def update_person(person_id,
                  person_current,
                  person_caches,
                  metric=opt.person_metric,
                  person_threshold=opt.person_threshold):
    """Update persons by matching them to cached persons with person re-ID.

    Returns:
        ``(person_current, person_caches, person_id)`` where ``person_id`` is
        the last identity ID handed out.
    """
    # Empty cache: every current person becomes a new identity.
    if not person_caches:
        for person in person_current:
            person_id += 1
            person.id = person_id
            person_caches.append(Person_Cache(person))
        return person_current, person_caches, person_id

    cost = compute_distance_matrix(combine_cur_pid(person_current),
                                   combine_cache_pid(person_caches),
                                   metric=metric)
    cost = compress_cost_matrix(cost)
    unmatched = list(range(len(person_current)))
    for a, b in linear_assignment(cost):
        if not (cost[a][b] < person_threshold):
            continue
        unmatched.remove(a)
        cur, cache = person_current[a], person_caches[b]
        cur.update_all(cache)
        # The name from the smallest face distance seen so far wins.
        if cur.fid_distance <= cache.fid_min_distance:
            cache.name = cur.name
            cache.fid_min_distance = cur.fid_distance
        else:
            cur.name = cache.name
        cur.fid_min_distance = cache.fid_min_distance
        cache.update_all(cur)

    # No match found: start a new identity.
    for idx in unmatched:
        person_id += 1
        person_current[idx].id = person_id
        person_caches.append(Person_Cache(person_current[idx]))

    return person_current, person_caches, person_id
Beispiel #10
0
def create_matrix(p1, p2, dist_metric="euclidean"):
    """Return the distance matrix between the features stored in two dicts.

    The values of each dict are concatenated along axis 0 (in dict order),
    axis 1 is squeezed out, and the result is converted to a tensor. Each
    value is presumably shaped ``(n, 1, dim)`` -- TODO confirm with callers.
    """
    def _stack(feature_dict):
        stacked = np.concatenate(list(feature_dict.values()), axis=0)
        return torch.tensor(np.squeeze(stacked, axis=1))

    return metrics.compute_distance_matrix(_stack(p1), _stack(p2), dist_metric)
Beispiel #11
0
def update_person(index,
                  person_cache: list,
                  cur_person_dict,
                  metric='euclidean',
                  person_threshold=100):
    """Match current persons to cached persons and assign identity IDs.

    Args:
        index: last identity ID handed out; incremented once per new person.
        person_cache: previously seen persons. It is mutated in place only
            when empty (new persons are appended); otherwise a deep copy is
            updated and returned.
        cur_person_dict: persons of the current frame, indexed positionally.
        metric: distance metric passed to ``compute_distance_matrix``.
        person_threshold: matches with distance below this value reuse the
            cached identity.

    Returns:
        ``(new_person_cache, cur_person_dict, index)``.
    """
    # Empty cache: every current person becomes a new identity.
    if not person_cache:
        for person in cur_person_dict:
            index += 1
            person.id = index
            person_cache.append(person)
        # Second element is the current persons, as in the matching branch.
        return person_cache, cur_person_dict, index

    cost_matrix = compute_distance_matrix(combine_pid(cur_person_dict),
                                          combine_pid(person_cache),
                                          metric=metric)
    matches = linear_assignment(cost_matrix)
    new_person_cache = copy.deepcopy(person_cache)
    not_found = list(range(len(cur_person_dict)))
    for a, b in matches:
        if cost_matrix[a][b] < person_threshold:
            not_found.remove(a)
            cur_person_dict[a].id = person_cache[b].id
            # Compare string values with ==/!=; `is` tests object identity and
            # only works when both strings happen to be the same interned object.
            if (cur_person_dict[a].name == "UnKnown"
                    and person_cache[b].name != "UnKnown"):
                cur_person_dict[a].name = person_cache[b].name
            new_person_cache[b] = cur_person_dict[a]

    # No match found: start a new identity.
    for i in not_found:
        index += 1
        cur_person_dict[i].id = index
        new_person_cache.append(cur_person_dict[i])

    return new_person_cache, cur_person_dict, index
Beispiel #12
0
def evaluate(model,
             queryloader,
             galleryloader,
             dist_metric='euclidean',
             normalize_feature=False,
             rerank=False,
             return_distmat=False,
             ranks=(1, 5, 10, 20)):
    """Evaluate a re-ID model on a query/gallery split and print CMC / mAP.

    Args:
        model: network mapping a batch of images to a feature matrix.
        queryloader: iterable of ``(imgs, pids, camids, _)`` query batches.
        galleryloader: iterable of ``(imgs, pids, camids, _)`` gallery batches.
        dist_metric: metric forwarded to ``compute_distance_matrix``.
        normalize_feature: L2-normalise features before computing distances.
        rerank: apply ``re_ranking`` to the query-gallery distances.
        return_distmat: return the distance matrix instead of rank-1 accuracy.
        ranks: 1-based CMC ranks to print.

    Returns:
        The numpy query-gallery distance matrix if ``return_distmat`` is true,
        otherwise the rank-1 CMC accuracy ``cmc[0]``.
    """
    batch_time = AverageMeter()

    def _extract(loader):
        # Stack CPU features and collect person / camera IDs in loader order.
        feats, pids_all, camids_all = [], [], []
        for imgs, pids, camids, _ in loader:
            imgs = imgs.cuda()
            end = time.time()
            features = model(imgs)
            batch_time.update(time.time() - end)
            feats.append(features.data.cpu())
            pids_all.extend(pids)
            camids_all.extend(camids)
        return torch.cat(feats, 0), np.asarray(pids_all), np.asarray(camids_all)

    model.eval()
    with torch.no_grad():
        print('Extracting features from query set ...')
        qf, q_pids, q_camids = _extract(queryloader)
        print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        print('Extracting features from gallery set ...')
        gf, g_pids, g_camids = _extract(galleryloader)
        print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

    print('Speed: {:.4f} sec/batch'.format(batch_time.avg))

    if normalize_feature:
        print('Normalzing features with L2 norm ...')
        qf = F.normalize(qf, p=2, dim=1)
        gf = F.normalize(gf, p=2, dim=1)

    print('Computing distance matrix with metric={} ...'.format(dist_metric))
    distmat = compute_distance_matrix(qf, gf, dist_metric)
    distmat = distmat.numpy()

    if rerank:
        print('Applying person re-ranking ...')
        distmat_qq = compute_distance_matrix(qf, qf, dist_metric)
        distmat_gg = compute_distance_matrix(gf, gf, dist_metric)
        distmat = re_ranking(distmat, distmat_qq, distmat_gg)

    print('Computing CMC and mAP ...')
    cmc, mAP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids)
    print('** Results **')
    print('mAP: {:.1%}'.format(mAP))
    print('CMC curve')
    for r in ranks:
        print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))

    if return_distmat:
        return distmat

    return cmc[0]
Beispiel #13
0
    def _evaluate(self, arch, epoch, dataset_name='', queryloader=None, galleryloader=None,
                  dist_metric='euclidean', normalize_feature=False, visrank=False,
                  visrank_topk=20, save_dir='', use_metric_cuhk03=False, ranks=(1, 5, 10, 20),
                  rerank=False, viscam=False, viscam_num=10, viscam_only=False):
        """Evaluate on one query/gallery split, log the metrics and optionally visualise.

        Query and gallery features come from ``self._extract_features``.
        Gallery features are then passed through ``self.combine_fn``, which
        also returns the gallery person IDs to use. After that come the
        distance computation, optional re-ranking, and CMC/mAP.

        Args:
            arch: architecture name; selects the final conv layer for ``visualize_cam``.
            epoch: current epoch, used as logging step and in output folder names.
            dataset_name: name of the evaluated test dataset (for visualisation).
            queryloader: query data loader, parsed by ``self._parse_data_for_eval``.
            galleryloader: gallery data loader, parsed the same way.
            dist_metric: metric forwarded to ``metrics.compute_distance_matrix``.
            normalize_feature: L2-normalise features before computing distances.
            visrank: save ranked-result visualisations (needs evaluation to run).
            visrank_topk: number of gallery hits shown per query by ``visrank``.
            save_dir: root directory for visualisation output.
            use_metric_cuhk03: forwarded to ``metrics.evaluate_rank``.
            ranks: 1-based CMC ranks to print and log. A tuple default avoids a
                shared mutable default argument.
            rerank: apply ``re_ranking`` to the distance matrix.
            viscam: save CAM visualisations (only for the supported ``arch`` values).
            viscam_num: number of images for CAM visualisation.
            viscam_only: skip evaluation, run only CAM visualisation, then raise.

        Returns:
            Rank-1 CMC accuracy, ``cmc[0]``.

        Raises:
            RuntimeError: whenever ``viscam_only`` is true (after visualisation).
        """
        def _extract(loader, timer):
            # Stacked CPU features plus person and camera IDs as numpy arrays.
            feats, pids_all, camids_all = [], [], []
            for data in loader:
                imgs, pids, camids = self._parse_data_for_eval(data)
                if self.use_gpu:
                    imgs = imgs.cuda()
                end = time.time()
                features = self._extract_features(imgs)
                timer.update(time.time() - end, len(pids), True)
                feats.append(features.data.cpu())
                pids_all.extend(pids)
                camids_all.extend(camids)
            return torch.cat(feats, 0), np.asarray(pids_all), np.asarray(camids_all)

        with self.experiment.test():
            if not viscam_only:
                batch_time = AverageMeter()
                combine_time = AverageMeter()

                self.model.eval()

                print('Extracting features from query set ...')
                qf, q_pids, q_camids = _extract(queryloader, batch_time)
                print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

                print('Extracting features from gallery set ...')
                gf, g_pids, g_camids = _extract(galleryloader, batch_time)

                end = time.time()
                num_images = len(g_pids)
                # NOTE(review): combine_fn is switched to train mode during
                # evaluation; presumably required by the combining method -- confirm.
                self.combine_fn.train()
                gf, g_pids = self.combine_fn(gf, g_pids, g_camids)
                if self.save_embed:
                    assert osp.isdir(self.save_embed)
                    path = osp.realpath(self.save_embed)
                    np.save(osp.join(path, 'gf-' + self.combine_method + '.npy'), gf)
                    np.save(osp.join(path, 'g_pids-' + self.combine_method + '.npy'), g_pids)
                combine_time.update(time.time() - end, num_images, True)
                # combine_fn may return an array or a tensor; as_tensor accepts
                # both without torch.tensor's copy-construct warning on tensors.
                gf = torch.as_tensor(gf, dtype=torch.float)
                print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

                print('Speed: {:.4f} sec/image'.format(batch_time.avg + combine_time.avg))

                if normalize_feature:
                    print('Normalzing features with L2 norm ...')
                    qf = F.normalize(qf, p=2, dim=1)
                    gf = F.normalize(gf, p=2, dim=1)

                print('Computing distance matrix with metric={} ...'.format(dist_metric))
                distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
                distmat = distmat.numpy()

                if rerank:
                    print('Applying person re-ranking ...')
                    distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
                    distmat_gg = metrics.compute_distance_matrix(gf, gf, dist_metric)
                    distmat = re_ranking(distmat, distmat_qq, distmat_gg)

                print('Computing CMC and mAP ...')
                cmc, mAP = metrics.evaluate_rank(
                    distmat,
                    q_pids,
                    g_pids,
                    q_camids,
                    g_camids,
                    use_metric_cuhk03=use_metric_cuhk03
                )

                print('** Results **')
                print('mAP: {:.1%}'.format(mAP))
                print('CMC curve')
                for r in ranks:
                    print('Rank-{:<3}: {:.1%}'.format(r, cmc[r-1]))

                # write to Tensorboard and comet.ml
                if not self.test_only:
                    rs = {'eval-rank-{:<3}'.format(r): cmc[r-1] for r in ranks}
                    self.writer.add_scalars('eval/ranks', rs, epoch)
                    self.experiment.log_metrics(rs, step=epoch)
                    self.writer.add_scalar('eval/mAP', mAP, epoch)
                    self.experiment.log_metric('eval-mAP', mAP, step=epoch)
                    print('Results written to tensorboard and comet.ml.')

            # distmat exists only when evaluation ran; with viscam_only it
            # would be undefined and raise NameError.
            if visrank and not viscam_only:
                visualize_ranked_results(
                    distmat,
                    self.datamanager.return_testdataset_by_name(dataset_name),
                    save_dir=osp.join(save_dir, 'visrank-'+str(epoch+1), dataset_name),
                    topk=visrank_topk
                )

            if viscam:
                # Final conv layer per supported architecture; others are skipped.
                finalconv = {
                    'osnet_x1_0': 'conv5',
                    'osnet_custom': 'conv5',
                    'resnext50_32x4d': 'layer4',
                }.get(arch)
                if finalconv is not None:
                    visualize_cam(
                        model=self.model,
                        finalconv=finalconv,
                        dataset=self.datamanager.return_testdataset_by_name(dataset_name),
                        save_dir=osp.join(save_dir, 'viscam-'+str(epoch+1), dataset_name),
                        num=viscam_num
                    )

        if viscam_only:
            raise RuntimeError('Stop exec because `viscam_only` is set to true.')

        return cmc[0]
    def train(
            self,
            epoch,
            max_epoch,
            writer,
            print_freq=10,
            fixbase_epoch=0,
            open_layers=None,
    ):
        losses_triplet = AverageMeter()
        losses_softmax = AverageMeter()
        losses_mmd_bc = AverageMeter()
        losses_mmd_wc = AverageMeter()
        losses_mmd_global = AverageMeter()
        losses_recons = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        
        

        self.model.train()
        self.mgn_targetPredict.train()
       
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print(
                '* Only train {} (epoch: {}/{})'.format(
                    open_layers, epoch + 1, fixbase_epoch
                )
            )
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)
            open_all_layers(self.mgn_targetPredict)
            print("All open layers!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

        num_batches = len(self.train_loader)
        end = time.time()
       
# -------------------------------------------------------------------------------------------------------------------- #
        for batch_idx, (data, data_t) in enumerate(zip(self.train_loader, self.train_loader_t)):
            data_time.update(time.time() - end)
            

            imgs, pids = self._parse_data_for_train(data)
            imgs_clean =  imgs.clone().cuda()
            lam=0
            imgs_t, pids_t = self._parse_data_for_train(data_t)
            imagest_orig=imgs_t.cuda()
            labels=[]
            labelss=[]
            random_indexS = np.random.randint(0, imgs.size()[0])
            random_indexT = np.random.randint(0, imgs_t.size()[0])
            if epoch > 10 and epoch < 35:
                
                for i, img in enumerate(imgs):
                  
                   randmt = RandomErasing(probability=0.5,sl=0.07, sh=0.22)
                  
                   imgs[i],p = randmt(img, imgs[random_indexS])
                   labelss.append(p)
               
            if epoch >= 35:
                randmt = RandomErasing(probability=0.5,sl=0.1, sh=0.25)
                for i, img in enumerate(imgs):
                  
                   imgs[i],p = randmt(img,imgs[random_indexS])
                   labelss.append(p)

            





            
            if epoch > 10 and epoch < 35:
                randmt = RandomErasing(probability=0.5,sl=0.1, sh=0.2)
                for i, img in enumerate(imgs_t):
                   
                   imgs_t[i],p = randmt(img,imgs_t[random_indexT])
                   labels.append(p)
               
            if epoch >= 35 and epoch < 75:
                randmt = RandomErasing(probability=0.5,sl=0.2, sh=0.3)
                for i, img in enumerate(imgs_t):
                  
                   imgs_t[i],p = randmt(img,imgs_t[random_indexT])
                   labels.append(p)

            if epoch >= 75:
                randmt = RandomErasing(probability=0.5,sl=0.2, sh=0.35)
                for i, img in enumerate(imgs_t):
                   
                  
                   imgs_t[i],p = randmt(img,imgs_t[random_indexT])
                   labels.append(p)
           
            binary_labels = torch.tensor(np.asarray(labels)).cuda()
            binary_labelss = torch.tensor(np.asarray(labelss)).cuda()
            
               
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()
            if self.use_gpu:
                imgs_transformed = imgs_t.cuda()

            

            self.optimizer.zero_grad()
           
            imgs_clean = imgs
            outputs, output2, recons,bcc1, bocc2,bocc3 = self.model(imgs)

            occ_losss1 = self.BCE_criterion(bcc1.squeeze(1),binary_labelss.float() )
            occ_losss2 = self.BCE_criterion(bocc2.squeeze(1),binary_labelss.float() )
            occ_losss3 = self.BCE_criterion(bocc3.squeeze(1),binary_labelss.float() )

            occ_s  = occ_losss1  +occ_losss2+occ_losss3
       
           

          

            ##############CUT MIX#################################3333
            """bbx1, bby1, bbx2, bby2 = self.rand_bbox(imgs.size(), lam)
            rand_index = torch.randperm(imgs.size()[0]).cuda()
            imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[rand_index, :, bbx1:bbx2, bby1:bby2]
            targeta = pids
            targetb = pids[rand_index]"""

            ##############CUT MIX#################################3333

            outputs_t, output2_t, recons_t,bocct1, bocct2,bocct3 = self.model(imagest_orig)
            outputs_t = self.mgn_targetPredict(output2_t)
           


            loss_reconst=self.criterion_mse(recons_t, imagest_orig)
            loss_recons=self.criterion_mse(recons, imgs_clean)

         
            occ_loss1 = self.BCE_criterion(bocct1.squeeze(1),binary_labels.float() )
            occ_loss2 = self.BCE_criterion(bocct2.squeeze(1),binary_labels.float() )
            occ_loss3 = self.BCE_criterion(bocct3.squeeze(1),binary_labels.float() )
            occ_t = occ_loss1 + occ_loss2 + occ_loss3
            pids_t = pids_t.cuda()
            loss_x = self.mgn_loss(outputs, pids)
            loss_x_t = self.mgn_loss(outputs_t, pids_t)
            #loss_x_t = self._compute_loss(self.criterion_x, y, targeta)  #*lam + self._compute_loss(self.criterion_x, y, targetb)*(1-lam)
            #loss_t_t = self._compute_loss(self.criterion_t, features_t, targeta)*lam + self._compute_loss(self.criterion_t, features_t, targetb)*(1-lam)
                      
         
            if epoch > 10:

                loss_mmd_wc, loss_mmd_bc, loss_mmd_global = self._compute_loss(self.criterion_mmd, outputs[0],  outputs_t[0])
                #loss_mmd_wc1, loss_mmd_bc1, loss_mmd_global1  = self._compute_loss(self.criterion_mmd, outputs[2], outputs_t[2])
                #loss_mmd_wc3, loss_mmd_bc3, loss_mmd_global3  = self._compute_loss(self.criterion_mmd, outputs[3], outputs_t[3])
                
                #loss_mmd_wcf  = loss_mmd_wc+loss_mmd_wc1+loss_mmd_wc3
                #loss_mmd_bcf  = loss_mmd_bc+loss_mmd_bc1+loss_mmd_bc3
                #loss_mmd_globalf  = loss_mmd_global+loss_mmd_global1+loss_mmd_global3
                

                
                #print(loss_mmd_bc.item())

                l_joint =  1.5*loss_x_t  +loss_x +loss_reconst+loss_recons  #self.weight_r*loss_recons+ + loss_x + loss_t 
                #loss = loss_t + loss_x + loss_mmd_bc + loss_mmd_wc
                l_d =   0.5*loss_mmd_bc + 0.8*loss_mmd_wc    +loss_mmd_global #+loss_mmd_bc1 + loss_mmd_wc1    +loss_mmd_global1 +loss_mmd_bc3 + loss_mmd_wc3   +loss_mmd_global3
                loss =  0.3*l_d + 0.7*l_joint +0.2*occ_t + 0.1*occ_s

                

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
# -------------------------------------------------------------------------------------------------------------------- #

            batch_time.update(time.time() - end)
            #losses_triplet.update(loss_t.item(), pids.size(0))
            losses_softmax.update(loss_x_t.item(), pids.size(0))
            #losses_recons.update(loss_recons.item(), pids.size(0))
            if epoch > 10:
                losses_mmd_bc.update(loss_mmd_bc.item(), pids.size(0))
                losses_mmd_wc.update(loss_mmd_wc.item(), pids.size(0))
                losses_mmd_global.update(loss_mmd_global.item(), pids.size(0))

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (
                        num_batches - (batch_idx + 1) + (max_epoch -
                                                         (epoch + 1)) * num_batches
                )
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    #'Loss_t {losses1.val:.4f} ({losses1.avg:.4f})\t'
                    'Loss_x {losses2.val:.4f} ({losses2.avg:.4f})\t'
                    'Loss_mmd_wc {losses3.val:.4f} ({losses3.avg:.4f})\t'
                    'Loss_mmd_bc {losses4.val:.4f} ({losses4.avg:.4f})\t'
                    'Loss_mmd_global {losses5.val:.4f} ({losses5.avg:.4f})\t'
                    #'Loss_recons {losses6.val:.4f} ({losses6.avg:.4f})\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        #losses1=losses_triplet,
                        losses2=losses_softmax,
                        losses3=losses_mmd_wc,
                        losses4=losses_mmd_bc,
                        losses5=losses_mmd_global,
                        #losses6 = losses_recons,
                        eta=eta_str
                    )
                )
            writer = None
            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Loss_triplet', losses_triplet.avg, n_iter)
                writer.add_scalar('Train/Loss_softmax', losses_softmax.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_bc', losses_mmd_bc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_wc', losses_mmd_wc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_global', losses_mmd_global.avg, n_iter)
                writer.add_scalar(
                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
                )

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()
        print_distri = True

        if print_distri:

            instances = self.datamanager.test_loader.query_loader.num_instances
            batch_size = self.datamanager.test_loader.batch_size
            feature_size = outputs[0].size(1) # features_t.shape[1]  # 2048
            features_t = outputs_t[0]
            features = outputs[0]
            t = torch.reshape(features_t, (int(batch_size / instances), instances, feature_size))
 
            #  and compute bc/wc euclidean distance
            bct = compute_distance_matrix(t[0], t[0])
            wct = compute_distance_matrix(t[0], t[1])
            for i in t[1:]:
                bct = torch.cat((bct, compute_distance_matrix(i, i)))
                for j in t:
                    if j is not i:
                        wct = torch.cat((wct, compute_distance_matrix(i, j)))

            s = torch.reshape(features, (int(batch_size / instances), instances, feature_size))
            bcs = compute_distance_matrix(s[0], s[0])
            wcs = compute_distance_matrix(s[0], s[1])
            for i in s[1:]:
                bcs = torch.cat((bcs, compute_distance_matrix(i, i)))
                for j in s:
                    if j is not i:
                        wcs = torch.cat((wcs, compute_distance_matrix(i, j)))

            bcs = bcs.detach()
            wcs = wcs.detach()

            b_c = [x.cpu().detach().item() for x in bcs.flatten() if x > 0.000001]
            w_c = [x.cpu().detach().item() for x in wcs.flatten() if x > 0.000001]
            data_bc = norm.rvs(b_c)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_c)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequency')
            plt.title('Source Domain')
            plt.legend()
            plt.savefig("Source.png")
            plt.clf()
            b_ct = [x.cpu().detach().item() for x in bct.flatten() if x > 0.1]
            w_ct = [x.cpu().detach().item() for x in wct.flatten() if x > 0.1]
            data_bc = norm.rvs(b_ct)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_ct)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequency')
            plt.title('Target Domain')
            plt.legend()
            plt.savefig("Target.png")
Beispiel #15
0
    def _evaluate(self,
                  epoch,
                  dataset_name='',
                  queryloader=None,
                  galleryloader=None,
                  dist_metric='euclidean',
                  normalize_feature=False,
                  visrank=False,
                  visrankactiv=False,
                  visrank_topk=10,
                  save_dir='',
                  use_metric_cuhk03=False,
                  ranks=(1, 5, 10, 20),
                  rerank=False,
                  visrankactivthr=False,
                  maskthr=0.7,
                  visdrop=False,
                  visdroptype='random'):
        """Evaluate re-ID ranking on one query/gallery split and visualise it.

        Features, activations and image-drop masks are extracted for every
        query and gallery batch. The query-gallery distance matrix is scored
        with ``metrics.evaluate_rank`` and mAP / CMC are printed, optionally
        a second time after ``re_ranking``.

        Args:
            epoch: current epoch (not used by this method).
            dataset_name: name used to look up the test dataset and to name
                the visualisation output directories.
            queryloader: iterable of query evaluation batches.
            galleryloader: iterable of gallery evaluation batches.
            dist_metric: metric passed to ``metrics.compute_distance_matrix``.
            normalize_feature: L2-normalise features before matching.
            visrank: save ranked-result visualisations.
            visrankactiv: save ranked results with activation maps.
            visrank_topk: number of gallery matches shown per query.
            save_dir: root directory for visualisation output.
            use_metric_cuhk03: use the CUHK03 evaluation protocol.
            ranks: CMC ranks to print (any iterable of ints).
            rerank: also report results after re-ranking. The re-ranked
                distance matrix is then used for the visualisations.
            visrankactivthr: save ranked results with thresholded activations.
            maskthr: threshold used by ``visrankactivthr``.
            visdrop: save ranked results with image drop masks.
            visdroptype: drop-mask type forwarded to ``_extract_drop_masks``.

        Returns:
            Rank-1 accuracy ``cmc[0]`` (after re-ranking when ``rerank``).
        """
        batch_time = AverageMeter()

        def _extract(loader):
            """Return (features, activations, drop masks, pids, camids) of ``loader``."""
            feats, acts, masks, pids_all, camids_all = [], [], [], [], []
            for data in loader:
                imgs, pids, camids = self._parse_data_for_eval(data)
                if self.use_gpu:
                    imgs = imgs.cuda()
                end = time.time()
                features = self._extract_features(imgs)
                activations = self._extract_activations(imgs)
                dropmask = self._extract_drop_masks(imgs, visdrop, visdroptype)
                batch_time.update(time.time() - end)
                feats.append(features.data.cpu())
                acts.append(torch.Tensor(activations))
                masks.append(torch.Tensor(dropmask))
                pids_all.extend(pids)
                camids_all.extend(camids)
            return (torch.cat(feats, 0), torch.cat(acts, 0),
                    torch.cat(masks, 0), np.asarray(pids_all),
                    np.asarray(camids_all))

        print('Extracting features from query set ...')
        # query features, activations, drop masks, person IDs and camera IDs
        qf, qa, qm, q_pids, q_camids = _extract(queryloader)
        print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        print('Extracting features from gallery set ...')
        # gallery features, activations, drop masks, person IDs and camera IDs
        gf, ga, gm, g_pids, g_camids = _extract(galleryloader)
        print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

        print('Speed: {:.4f} sec/batch'.format(batch_time.avg))

        if normalize_feature:
            print('Normalzing features with L2 norm ...')
            qf = F.normalize(qf, p=2, dim=1)
            gf = F.normalize(gf, p=2, dim=1)

        print(
            'Computing distance matrix with metric={} ...'.format(dist_metric))
        distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
        distmat = distmat.numpy()

        # Always show results without re-ranking first.
        print('Computing CMC and mAP ...')
        cmc, mAP = metrics.evaluate_rank(distmat,
                                         q_pids,
                                         g_pids,
                                         q_camids,
                                         g_camids,
                                         use_metric_cuhk03=use_metric_cuhk03)

        print('** Results **')
        print('mAP: {:.1%}'.format(mAP))
        print('CMC curve')
        for r in ranks:
            print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))

        if rerank:
            print('Applying person re-ranking ...')
            distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
            distmat_gg = metrics.compute_distance_matrix(gf, gf, dist_metric)
            distmat = re_ranking(distmat, distmat_qq, distmat_gg)
            print('Computing CMC and mAP ...')
            cmc, mAP = metrics.evaluate_rank(
                distmat,
                q_pids,
                g_pids,
                q_camids,
                g_camids,
                use_metric_cuhk03=use_metric_cuhk03)

            print('** Results with Re-Ranking**')
            print('mAP: {:.1%}'.format(mAP))
            print('CMC curve')
            for r in ranks:
                print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))

        if visrank:
            visualize_ranked_results(
                distmat,
                self.datamanager.return_testdataset_by_name(dataset_name),
                self.datamanager.data_type,
                width=self.datamanager.width,
                height=self.datamanager.height,
                save_dir=osp.join(save_dir, 'visrank_' + dataset_name),
                topk=visrank_topk)
        if visrankactiv:
            visualize_ranked_activation_results(
                distmat,
                qa,
                ga,
                self.datamanager.return_testdataset_by_name(dataset_name),
                self.datamanager.data_type,
                width=self.datamanager.width,
                height=self.datamanager.height,
                save_dir=osp.join(save_dir, 'visrankactiv_' + dataset_name),
                topk=visrank_topk)
        if visrankactivthr:
            visualize_ranked_threshold_activation_results(
                distmat,
                qa,
                ga,
                self.datamanager.return_testdataset_by_name(dataset_name),
                self.datamanager.data_type,
                width=self.datamanager.width,
                height=self.datamanager.height,
                save_dir=osp.join(save_dir, 'visrankactivthr_' + dataset_name),
                topk=visrank_topk,
                threshold=maskthr)
        if visdrop:
            visualize_ranked_mask_activation_results(
                distmat,
                qa,
                ga,
                qm,
                gm,
                self.datamanager.return_testdataset_by_name(dataset_name),
                self.datamanager.data_type,
                width=self.datamanager.width,
                height=self.datamanager.height,
                save_dir=osp.join(
                    save_dir, 'visdrop_{}_{}'.format(visdroptype,
                                                     dataset_name)),
                topk=visrank_topk)

        return cmc[0]
Beispiel #16
0
    def train(
        self,
        epoch,
        max_epoch,
        writer,
        print_freq=1,
        fixbase_epoch=0,
        open_layers=None,
    ):
        """Train for one epoch on zipped source/target batches.

        Per step, the source batch is perturbed with ``self.random`` and the
        target batch with ``self.random2``, and both go through ``self.model``.
        The optimised objective is

            loss_t + loss_x + weight_r * loss_r1 + 0 * loss_r2
            + loss_mmd_wc + loss_mmd_bc + loss_mmd_global + parts_loss

        where ``parts_loss`` is the mean ``criterion_x`` loss over the part
        classifier outputs. ``local_loss`` (an RBF-MMD between local
        correlation matrices) is computed and reported but not optimised.

        Args:
            epoch: zero-based index of the current epoch.
            max_epoch: total number of epochs, used for the ETA estimate.
            writer: object with ``add_scalar``; ``None`` disables logging.
            print_freq: print progress every ``print_freq`` batches.
            fixbase_epoch: while ``epoch + 1 <= fixbase_epoch`` only
                ``open_layers`` are trained.
            open_layers: layers to open during the fixbase phase.
        """
        losses_triplet = AverageMeter()
        losses_softmax = AverageMeter()
        losses_recons_s = AverageMeter()
        losses_recons_t = AverageMeter()
        losses_local = AverageMeter()

        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print('* Only train {} (epoch: {}/{})'.format(
                open_layers, epoch + 1, fixbase_epoch))
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()
        weight_r = self.weight_r
        # -------------------------------------------------------------------------------------------------------------------- #
        for batch_idx, (data, data_t) in enumerate(
                zip(self.train_loader, self.train_loader_t)):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            # Target labels are not used by any loss below.
            imgs_t, _ = self._parse_data_for_train(data_t)
            if self.use_gpu:
                imgs_t = imgs_t.cuda()

            self.optimizer.zero_grad()

            # Source branch: the model sees a perturbed copy of imgs.
            noisy_imgs = self.random(imgs)
            outputs, part_outs, features, recons, z, mean, var, local_feat = self.model(
                noisy_imgs)
            # Identity loss averaged over all part classifiers.
            parts_loss = sum(
                self._compute_loss(self.criterion_x, out, pids)
                for out in part_outs) / len(part_outs)

            # Target branch.
            imgs_t = self.random2(imgs_t)
            outputs_t, parts_out_t, features_t, recons_t, z_t, mean_t, var_t, local_feat_t = self.model(
                imgs_t)

            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
            # NOTE(review): loss_r1 scores the reconstruction of noisy_imgs
            # against the clean imgs, whereas loss_r2 scores against the
            # already-perturbed imgs_t; confirm the asymmetry is intended.
            loss_r1 = self.loss_vae(imgs, recons, mean, var)
            loss_r2 = self.loss_vae(imgs_t, recons_t, mean_t, var_t)

            # The source side is detached so gradients only flow through
            # the target correlation matrix.
            dist_mat_s = self.get_local_correl(local_feat).detach()
            dist_mat_t = self.get_local_correl(local_feat_t)
            local_loss = self.criterion_mmd.mmd_rbf_noaccelerate(
                dist_mat_s, dist_mat_t)

            loss_mmd_wc, loss_mmd_bc, loss_mmd_global = self._compute_loss(
                self.criterion_mmd, features, features_t)
            # A zero weight on the target reconstruction (weight_r2 = 0)
            # worked best; loss_r2 is still computed for logging.
            loss = (loss_t + loss_x + weight_r * loss_r1 + 0 * loss_r2 +
                    loss_mmd_wc + loss_mmd_bc + loss_mmd_global + parts_loss)

            loss.backward()
            self.optimizer.step()
            # -------------------------------------------------------------------------------------------------------------------- #

            batch_time.update(time.time() - end)
            losses_triplet.update(loss_t.item(), pids.size(0))
            losses_softmax.update(loss_x.item(), pids.size(0))
            losses_recons_s.update(loss_r1.item(), pids.size(0))
            losses_recons_t.update(loss_r2.item(), pids.size(0))
            losses_local.update(local_loss.item(), pids.size(0))

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (num_batches - (batch_idx + 1) +
                                                (max_epoch -
                                                 (epoch + 1)) * num_batches)
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss_t {losses1.val:.4f} ({losses1.avg:.4f})\t'
                      'Loss_x {losses2.val:.4f} ({losses2.avg:.4f})\t'
                      'Loss_reconsS {losses4.val:.4f} ({losses4.avg:.4f})\t'
                      'Loss_reconsT {losses5.val:.4f} ({losses5.avg:.4f})\t'
                      'Loss_local {losses6.val:.4f} ({losses6.avg:.4f})\t'
                      'eta {eta}'.format(epoch + 1,
                                         max_epoch,
                                         batch_idx + 1,
                                         num_batches,
                                         batch_time=batch_time,
                                         losses1=losses_triplet,
                                         losses2=losses_softmax,
                                         losses4=losses_recons_s,
                                         losses5=losses_recons_t,
                                         losses6=losses_local,
                                         eta=eta_str))

            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Loss_triplet', losses_triplet.avg,
                                  n_iter)
                writer.add_scalar('Train/Loss_softmax', losses_softmax.avg,
                                  n_iter)

                writer.add_scalar('Train/Loss_recons_s', losses_recons_s.avg,
                                  n_iter)
                writer.add_scalar('Train/Loss_recons_t', losses_recons_t.avg,
                                  n_iter)

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()

        # Debug switch: plot within/between-identity distance histograms of
        # the last batch's local features.
        print_distri = False

        if print_distri:
            print("Printing distribution")
            instances = self.datamanager.train_loader.sampler.num_instances
            batch_size = self.datamanager.train_loader.batch_size
            num_ids = int(batch_size / instances)

            def split_distances(feats):
                """Return 1-D (within-identity, between-identity) distances.

                ``feats`` is indexed as feats[identity][instance]. Every
                ordered pair of distinct identity groups counts as
                between-identity (groups are compared by index, not by
                object identity).
                """
                within, between = [], []
                for a, group_a in enumerate(feats):
                    for b, group_b in enumerate(feats):
                        dist = compute_distance_matrix(group_a, group_b)
                        target = within if a == b else between
                        target.append(dist.detach().flatten())
                return torch.cat(within), torch.cat(between)

            def plot_distances(within, between, threshold, ylabel, title,
                               path):
                """Histogram both distance sets and save the figure to ``path``.

                Values <= ``threshold`` are dropped, presumably to remove
                the zero self-distances on the within-identity diagonals.
                """
                within = within[within > threshold].cpu().numpy()
                between = between[between > threshold].cpu().numpy()
                sns.distplot(within,
                             bins='auto',
                             fit=norm,
                             kde=False,
                             label='from the same class (within class)')
                sns.distplot(between,
                             bins='auto',
                             fit=norm,
                             kde=False,
                             label='from different class (between class)')
                plt.xlabel('Euclidean distance')
                plt.ylabel(ylabel)
                plt.title(title)
                plt.legend()
                plt.savefig(path)
                plt.clf()

            # One flattened local-feature vector per image.
            flat_s = local_feat.reshape(local_feat.size(0), -1)
            flat_t = local_feat_t.reshape(local_feat_t.size(0), -1)
            feature_size = flat_t.size(1)
            s = torch.reshape(flat_s, (num_ids, instances, feature_size))
            t = torch.reshape(flat_t, (num_ids, instances, feature_size))

            within_s, between_s = split_distances(s)
            within_t, between_t = split_distances(t)

            plot_distances(
                within_s, between_s, 0.000001, 'Frequence of Occurance',
                'Source Domain',
                "/export/livia/home/vision/mkiran/work/Person_Reid/Video_Person/Domain_Adapt/D-MMD/figs/Non_Occluded_distribution.png"
            )
            plot_distances(
                within_t, between_t, 0.1, 'Frequence of apparition',
                'Non-Occluded Data Domain',
                "/export/livia/home/vision/mkiran/work/Person_Reid/Video_Person/Domain_Adapt/D-MMD/figs/Occluded_distribution.png"
            )
    def _evaluate_reid(self,
                       model,
                       epoch,
                       dataset_name='',
                       query_loader=None,
                       gallery_loader=None,
                       dist_metric='euclidean',
                       normalize_feature=False,
                       visrank=False,
                       visrank_topk=10,
                       save_dir='',
                       use_metric_cuhk03=False,
                       ranks=(1, 5, 10, 20),
                       rerank=False,
                       model_name='',
                       lr_finder=False):
        """Compute CMC / mAP of ``model`` on one query/gallery split.

        Args:
            model: callable mapping an image batch to a feature tensor.
            epoch: zero-based epoch; ``epoch + 1`` is the logging step.
            dataset_name: dataset name used in log tags and output paths.
            query_loader: iterable of query evaluation batches.
            gallery_loader: iterable of gallery evaluation batches.
            dist_metric: metric for ``metrics.compute_distance_matrix``.
            normalize_feature: L2-normalise features before matching.
            visrank: save ranked-result visualisations.
            visrank_topk: number of gallery matches shown per query.
            save_dir: root directory for visualisations.
            use_metric_cuhk03: use the CUHK03 evaluation protocol.
            ranks: CMC ranks to log and print.
            rerank: apply ``re_ranking`` before scoring.
            model_name: tag used in writer keys and printed headers.
            lr_finder: when true, skip logging, printing and visualisation.

        Returns:
            Rank-1 accuracy, ``cmc[0]``.
        """
        def _feature_extraction(data_loader):
            """Return (features, pids, camids) for every batch of ``data_loader``."""
            f_, pids_, camids_ = [], [], []
            for data in data_loader:
                imgs, pids, camids = self.parse_data_for_eval(data)
                if self.use_gpu:
                    imgs = imgs.cuda()

                # Keep the model output as a tensor so ``.data`` works.
                features = model(imgs)
                f_.append(features.data.cpu())
                pids_.extend(pids)
                camids_.extend(camids)

            f_ = torch.cat(f_, 0)
            pids_ = np.asarray(pids_)
            camids_ = np.asarray(camids_)

            return f_, pids_, camids_

        qf, q_pids, q_camids = _feature_extraction(query_loader)
        gf, g_pids, g_camids = _feature_extraction(gallery_loader)

        if normalize_feature:
            qf = F.normalize(qf, p=2, dim=1)
            gf = F.normalize(gf, p=2, dim=1)

        distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
        distmat = distmat.numpy()

        if rerank:
            distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
            distmat_gg = metrics.compute_distance_matrix(gf, gf, dist_metric)
            distmat = re_ranking(distmat, distmat_qq, distmat_gg)

        cmc, mAP = metrics.evaluate_rank(distmat,
                                         q_pids,
                                         g_pids,
                                         q_camids,
                                         g_camids,
                                         use_metric_cuhk03=use_metric_cuhk03)

        if self.writer is not None and not lr_finder:
            self.writer.add_scalar(
                'Val/{}/{}/mAP'.format(dataset_name, model_name), mAP,
                epoch + 1)
            for r in ranks:
                self.writer.add_scalar(
                    'Val/{}/{}/Rank-{}'.format(dataset_name, model_name, r),
                    cmc[r - 1], epoch + 1)
        if not lr_finder:
            print('** Results ({}) **'.format(model_name))
            print('mAP: {:.2%}'.format(mAP))
            print('CMC curve')
            for r in ranks:
                print('Rank-{:<3}: {:.2%}'.format(r, cmc[r - 1]))

        if visrank and not lr_finder:
            visualize_ranked_results(
                distmat,
                self.datamanager.fetch_test_loaders(dataset_name),
                self.datamanager.data_type,
                width=self.datamanager.width,
                height=self.datamanager.height,
                save_dir=osp.join(save_dir, 'visrank_' + dataset_name),
                topk=visrank_topk)

        return cmc[0]
Beispiel #18
0
    def _evaluate(self,
                  epoch,
                  dataset_name='',
                  queryloader=None,
                  galleryloader=None,
                  dist_metric='euclidean',
                  visrank=False,
                  visrank_topk=20,
                  save_dir='',
                  use_metric_cuhk03=False,
                  ranks=(1, 5, 10, 20)):
        """Evaluate ``self.model`` on one query/gallery split.

        Puts the model in eval mode, extracts query and gallery features,
        ranks the gallery for every query and prints mAP and the CMC curve.

        Args:
            epoch: zero-based epoch; used to name the visualisation folder.
            dataset_name: dataset name used for lookup and output paths.
            queryloader: iterable of query evaluation batches.
            galleryloader: iterable of gallery evaluation batches.
            dist_metric: metric for ``metrics.compute_distance_matrix``.
            visrank: save ranked-result visualisations.
            visrank_topk: number of gallery matches shown per query.
            save_dir: root directory for visualisations.
            use_metric_cuhk03: use the CUHK03 evaluation protocol.
            ranks: CMC ranks to print (any iterable of ints).

        Returns:
            Rank-1 accuracy, ``cmc[0]``.
        """
        batch_time = AverageMeter()

        self.model.eval()

        def _extract(loader):
            """Return (features, pids, camids) for every batch of ``loader``."""
            feats, pids_all, camids_all = [], [], []
            for data in loader:
                imgs, pids, camids = self._parse_data_for_eval(data)
                if self.use_gpu:
                    imgs = imgs.cuda()
                end = time.time()
                features = self._extract_features(imgs)
                batch_time.update(time.time() - end)
                feats.append(features.data.cpu())
                pids_all.extend(pids)
                camids_all.extend(camids)
            return (torch.cat(feats, 0), np.asarray(pids_all),
                    np.asarray(camids_all))

        print('Extracting features from query set ...')
        qf, q_pids, q_camids = _extract(queryloader)
        print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        print('Extracting features from gallery set ...')
        gf, g_pids, g_camids = _extract(galleryloader)
        print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

        print('Speed: {:.4f} sec/batch'.format(batch_time.avg))

        distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
        distmat = distmat.numpy()

        print('Computing CMC and mAP ...')
        cmc, mAP = metrics.evaluate_rank(distmat,
                                         q_pids,
                                         g_pids,
                                         q_camids,
                                         g_camids,
                                         use_metric_cuhk03=use_metric_cuhk03)

        print('** Results **')
        print('mAP: {:.1%}'.format(mAP))
        print('CMC curve')
        for r in ranks:
            print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))

        if visrank:
            visualize_ranked_results(
                distmat,
                self.datamanager.return_testdataset_by_name(dataset_name),
                save_dir=osp.join(save_dir, 'visrank-' + str(epoch + 1),
                                  dataset_name),
                topk=visrank_topk)

        return cmc[0]
Beispiel #19
0
    def train(
            self,
            epoch,
            max_epoch,
            writer,
            print_freq=10,
            fixbase_epoch=0,
            open_layers=None,
    ):
        losses_triplet = AverageMeter()
        losses_softmax = AverageMeter()
        losses_mmd_bc = AverageMeter()
        losses_mmd_wc = AverageMeter()
        losses_mmd_global = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()

        self.model.train()
        if (epoch + 1) <= fixbase_epoch and open_layers is not None:
            print(
                '* Only train {} (epoch: {}/{})'.format(
                    open_layers, epoch + 1, fixbase_epoch
                )
            )
            open_specified_layers(self.model, open_layers)
        else:
            open_all_layers(self.model)

        num_batches = len(self.train_loader)
        end = time.time()

# -------------------------------------------------------------------------------------------------------------------- #
        for batch_idx, (data, data_t) in enumerate(zip(self.train_loader, self.train_loader_t)):
            data_time.update(time.time() - end)

            imgs, pids = self._parse_data_for_train(data)
            if self.use_gpu:
                imgs = imgs.cuda()
                pids = pids.cuda()

            imgs_t, pids_t = self._parse_data_for_train(data_t)
            if self.use_gpu:
                imgs_t = imgs_t.cuda()

            self.optimizer.zero_grad()

            outputs, features = self.model(imgs)
            outputs_t, features_t = self.model(imgs_t)

            loss_t = self._compute_loss(self.criterion_t, features, pids)
            loss_x = self._compute_loss(self.criterion_x, outputs, pids)
            loss = loss_t + loss_x

            if epoch > 20:
                loss_mmd_wc, loss_mmd_bc, loss_mmd_global = self._compute_loss(self.criterion_mmd, features, features_t)
                #loss = loss_t + loss_x + loss_mmd_bc + loss_mmd_wc
                loss = loss_t + loss_x + loss_mmd_global + loss_mmd_bc + loss_mmd_wc

                if False:
                    loss_t = torch.tensor(0)
                    loss_x = torch.tensor(0)
                    #loss = loss_mmd_bc + loss_mmd_wc
                    loss = loss_mmd_bc + loss_mmd_wc + loss_mmd_global


            loss.backward()
            self.optimizer.step()
# -------------------------------------------------------------------------------------------------------------------- #

            batch_time.update(time.time() - end)
            losses_triplet.update(loss_t.item(), pids.size(0))
            losses_softmax.update(loss_x.item(), pids.size(0))
            if epoch > 24:
                losses_mmd_bc.update(loss_mmd_bc.item(), pids.size(0))
                losses_mmd_wc.update(loss_mmd_wc.item(), pids.size(0))
                losses_mmd_global.update(loss_mmd_global.item(), pids.size(0))

            if (batch_idx + 1) % print_freq == 0:
                # estimate remaining time
                eta_seconds = batch_time.avg * (
                        num_batches - (batch_idx + 1) + (max_epoch -
                                                         (epoch + 1)) * num_batches
                )
                eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
                print(
                    'Epoch: [{0}/{1}][{2}/{3}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss_t {losses1.val:.4f} ({losses1.avg:.4f})\t'
                    'Loss_x {losses2.val:.4f} ({losses2.avg:.4f})\t'
                    'Loss_mmd_wc {losses3.val:.4f} ({losses3.avg:.4f})\t'
                    'Loss_mmd_bc {losses4.val:.4f} ({losses4.avg:.4f})\t'
                    'Loss_mmd_global {losses5.val:.4f} ({losses5.avg:.4f})\t'
                    'eta {eta}'.format(
                        epoch + 1,
                        max_epoch,
                        batch_idx + 1,
                        num_batches,
                        batch_time=batch_time,
                        losses1=losses_triplet,
                        losses2=losses_softmax,
                        losses3=losses_mmd_wc,
                        losses4=losses_mmd_bc,
                        losses5=losses_mmd_global,
                        eta=eta_str
                    )
                )

            if writer is not None:
                n_iter = epoch * num_batches + batch_idx
                writer.add_scalar('Train/Time', batch_time.avg, n_iter)
                writer.add_scalar('Train/Loss_triplet', losses_triplet.avg, n_iter)
                writer.add_scalar('Train/Loss_softmax', losses_softmax.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_bc', losses_mmd_bc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_wc', losses_mmd_wc.avg, n_iter)
                writer.add_scalar('Train/Loss_mmd_global', losses_mmd_global.avg, n_iter)
                writer.add_scalar(
                    'Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter
                )

            end = time.time()

        if self.scheduler is not None:
            self.scheduler.step()

        print_distri = False

        if print_distri:

            instances = self.datamanager.train_loader.sampler.num_instances
            batch_size = self.datamanager.train_loader.batch_size
            feature_size = 2048 # features_t.shape[1]  # 2048
            t = torch.reshape(features_t, (int(batch_size / instances), instances, feature_size))

            #  and compute bc/wc euclidean distance
            bct = compute_distance_matrix(t[0], t[0])
            wct = compute_distance_matrix(t[0], t[1])
            for i in t[1:]:
                bct = torch.cat((bct, compute_distance_matrix(i, i)))
                for j in t:
                    if j is not i:
                        wct = torch.cat((wct, compute_distance_matrix(i, j)))

            s = torch.reshape(features, (int(batch_size / instances), instances, feature_size))
            bcs = compute_distance_matrix(s[0], s[0])
            wcs = compute_distance_matrix(s[0], s[1])
            for i in s[1:]:
                bcs = torch.cat((bcs, compute_distance_matrix(i, i)))
                for j in s:
                    if j is not i:
                        wcs = torch.cat((wcs, compute_distance_matrix(i, j)))

            bcs = bcs.detach()
            wcs = wcs.detach()

            b_c = [x.cpu().detach().item() for x in bcs.flatten() if x > 0.000001]
            w_c = [x.cpu().detach().item() for x in wcs.flatten() if x > 0.000001]
            data_bc = norm.rvs(b_c)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_c)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequence of apparition')
            plt.title('Source Domain')
            plt.legend()
            plt.show()

            b_ct = [x.cpu().detach().item() for x in bct.flatten() if x > 0.1]
            w_ct = [x.cpu().detach().item() for x in wct.flatten() if x > 0.1]
            data_bc = norm.rvs(b_ct)
            sns.distplot(data_bc, bins='auto', fit=norm, kde=False, label='from the same class (within class)')
            data_wc = norm.rvs(w_ct)
            sns.distplot(data_wc, bins='auto', fit=norm, kde=False, label='from different class (between class)')
            plt.xlabel('Euclidean distance')
            plt.ylabel('Frequence of apparition')
            plt.title('Target Domain')
            plt.legend()
            plt.show()
Beispiel #20
0
def test(model,
         queryloader,
         galleryloader,
         pool,
         use_gpu,
         ranks=(1, 5, 10, 20),
         return_distmat=False):
    """Evaluate a video re-id model on query/gallery loaders and report CMC/mAP.

    Features are extracted for every batch of ``queryloader`` and
    ``galleryloader``. A query-by-gallery distance matrix is computed with
    ``args.dist_metric`` and optionally re-ranked when ``args.re_rank`` is set.
    Scores are then computed with ``metrics.evaluate_rank(...,
    use_metric_mars=True)``.

    Args:
        model: network called as ``model(imgs, adj)``.
        queryloader / galleryloader: iterables yielding
            ``(imgs, pids, camids, adj)`` batches.
        pool: in dense sampling mode, ``'avg'`` averages the clip features
            of each sample; any other value takes the element-wise max.
        use_gpu: move ``imgs`` and ``adj`` to CUDA before the forward pass.
        ranks: CMC ranks to print.
        return_distmat: if true, return the numpy distance matrix instead
            of the scores.

    Returns:
        ``distmat`` when ``return_distmat`` is true, otherwise
        ``(cmc[0], mAP)``.

    Side effects: assigns the module-level global ``mAP`` and reads the
    module-level ``args`` (``test_sample``, ``test_batch``, ``seq_len``,
    ``dist_metric``, ``re_rank``).
    """
    global mAP
    batch_time = AverageMeter()
    dense = args.test_sample in ['dense', 'skipdense']

    def _extract(loader):
        """Return (features, pids, camids) for every batch of *loader*."""
        feats, all_pids, all_camids = [], [], []
        for imgs, pids, camids, adj in loader:
            if use_gpu:
                imgs, adj = imgs.cuda(), adj.cuda()
            if dense:
                # Each sample carries n clips: flatten (b, n) so the model
                # sees b*n independent clips.
                b, n, s, c, h, w = imgs.size()
                imgs = imgs.view(b * n, s, c, h, w)
                adj = adj.view(b * n, adj.size(-1), adj.size(-1))

            end = time.time()
            features = model(imgs, adj)
            batch_time.update(time.time() - end)
            if dense:
                # Regroup the b*n clip features per sample (same b-major order
                # as the flattening above) and pool over the clip axis.
                # The former view(n, 1, -1) was only correct for b == 1; for
                # b == 1 this produces the identical (1, D) result.
                features = features.view(b, n, -1)
                if pool == 'avg':
                    features = torch.mean(features, 1)
                else:
                    features, _ = torch.max(features, 1)
            feats.append(features.data.cpu())
            all_pids.extend(pids.numpy())
            all_camids.extend(camids.numpy())
        return torch.cat(feats, 0), np.asarray(all_pids), np.asarray(all_camids)

    model.eval()

    with torch.no_grad():
        qf, q_pids, q_camids = _extract(queryloader)
        print("Extracted features for query set, obtained {}-by-{} matrix".
              format(qf.size(0), qf.size(1)))

        gf, g_pids, g_camids = _extract(galleryloader)
        print("Extracted features for gallery set, obtained {}-by-{} matrix".
              format(gf.size(0), gf.size(1)))

    print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(
        batch_time.avg, args.test_batch * args.seq_len))

    print('Computing distance matrix with metric={} ...'.format(
        args.dist_metric))
    distmat = metrics.compute_distance_matrix(qf, gf, args.dist_metric)
    distmat = distmat.numpy()

    if args.re_rank:
        print('Applying person re-ranking ...')
        distmat_qq = metrics.compute_distance_matrix(qf, qf, args.dist_metric)
        distmat_gg = metrics.compute_distance_matrix(gf, gf, args.dist_metric)
        distmat = re_ranking(distmat, distmat_qq, distmat_gg)

    print("Computing CMC and mAP")

    cmc, mAP = metrics.evaluate_rank(distmat,
                                     q_pids,
                                     g_pids,
                                     q_camids,
                                     g_camids,
                                     use_metric_mars=True)

    print("Results ----------")
    print("mAP: {:.2%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.2%}".format(r, cmc[r - 1]))
    print("------------------")

    if return_distmat:
        return distmat
    return cmc[0], mAP
Beispiel #21
0
    def _evaluate(
        self,
        dataset_name='',
        query_loader=None,
        gallery_loader=None,
        dist_metric='euclidean',
        normalize_feature=False,
        visrank=False,
        visrank_topk=10,
        save_dir='',
        use_metric_cuhk03=False,
        ranks=(1, 5, 10, 20),
        rerank=False
    ):
        """Extract query/gallery features, rank the gallery and report CMC/mAP.

        Args:
            dataset_name: used to fetch the test loaders for visualization
                and to name the ``visrank_*`` output folder.
            query_loader / gallery_loader: iterables of batches accepted by
                ``self.parse_data_for_eval``.
            dist_metric: metric forwarded to ``metrics.compute_distance_matrix``.
            normalize_feature: L2-normalize features (dim=1) before computing
                distances.
            visrank / visrank_topk / save_dir: when ``visrank`` is true, save
                the top-``visrank_topk`` ranked results under ``save_dir``.
            use_metric_cuhk03: forwarded to ``metrics.evaluate_rank``.
            ranks: CMC ranks to print. Any iterable of ints works; the default
                is a tuple to avoid a shared mutable default.
            rerank: apply ``re_ranking`` to the distance matrix.

        Returns:
            ``(cmc[0], mAP)``: rank-1 accuracy and mean average precision.
        """
        batch_time = AverageMeter()

        def _feature_extraction(data_loader):
            """Return (features, pids, camids) for every batch of *data_loader*."""
            f_, pids_, camids_ = [], [], []
            # Inference only: without no_grad, ``.cpu().clone()`` would keep
            # every batch's autograd graph alive. Nesting is harmless if the
            # caller already disabled gradients.
            with torch.no_grad():
                for data in data_loader:
                    imgs, pids, camids = self.parse_data_for_eval(data)
                    if self.use_gpu:
                        imgs = imgs.cuda()
                    end = time.time()
                    features = self.extract_features(imgs)
                    batch_time.update(time.time() - end)
                    f_.append(features.cpu().clone())
                    pids_.extend(pids)
                    camids_.extend(camids)
            return torch.cat(f_, 0), np.asarray(pids_), np.asarray(camids_)

        print('Extracting features from query set ...')
        qf, q_pids, q_camids = _feature_extraction(query_loader)
        print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        print('Extracting features from gallery set ...')
        gf, g_pids, g_camids = _feature_extraction(gallery_loader)
        print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

        print('Speed: {:.4f} sec/batch'.format(batch_time.avg))

        if normalize_feature:
            print('Normalzing features with L2 norm ...')
            qf = F.normalize(qf, p=2, dim=1)
            gf = F.normalize(gf, p=2, dim=1)

        print(
            'Computing distance matrix with metric={} ...'.format(dist_metric)
        )
        distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
        distmat = distmat.numpy()

        if rerank:
            print('Applying person re-ranking ...')
            distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
            distmat_gg = metrics.compute_distance_matrix(gf, gf, dist_metric)
            distmat = re_ranking(distmat, distmat_qq, distmat_gg)

        print('Computing CMC and mAP ...')
        cmc, mAP = metrics.evaluate_rank(
            distmat,
            q_pids,
            g_pids,
            q_camids,
            g_camids,
            use_metric_cuhk03=use_metric_cuhk03
        )

        print('** Results **')
        print('mAP: {:.1%}'.format(mAP))
        print('CMC curve')
        for r in ranks:
            print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))

        if visrank:
            visualize_ranked_results(
                distmat,
                self.datamanager.fetch_test_loaders(dataset_name),
                self.datamanager.data_type,
                width=self.datamanager.width,
                height=self.datamanager.height,
                save_dir=osp.join(save_dir, 'visrank_' + dataset_name),
                topk=visrank_topk
            )

        return cmc[0], mAP
 def compute_distance(self, qf, gf):
     """Return the query-by-gallery distance matrix as a numpy array.

     The metric is taken from ``self.dist_metric`` and the computation is
     delegated to ``metrics.compute_distance_matrix``.
     """
     return metrics.compute_distance_matrix(
         qf, gf, metric=self.dist_metric
     ).numpy()
Beispiel #23
0
    def _evaluate(self, epoch, dataset_name='', queryloader=None, galleryloader=None,
                  dist_metric='euclidean', normalize_feature=False, visrank=False,
                  visrank_topk=10, save_dir='', use_metric_cuhk03=False, ranks=(1, 5, 10, 20),
                  rerank=False, iteration=0):
        """Evaluate on one dataset, log scalars and return rank-1 accuracy.

        Args:
            epoch: current epoch; scalars are logged at step ``epoch + 1``
                when ``self.writer`` is set.
            dataset_name: used in the logged scalar tags and in the
                ``visrank_*`` folder name.
            queryloader / galleryloader: iterables of batches. Each batch is
                parsed by ``self._parse_data_for_eval``, and its element
                ``data[3]`` is passed as the second argument of
                ``self._extract_features``.
            dist_metric: metric forwarded to ``metrics.compute_distance_matrix``.
            normalize_feature: L2-normalize features (dim=1) first.
            visrank / visrank_topk / save_dir: optionally save the top ranked
                gallery items under ``save_dir``.
            use_metric_cuhk03: forwarded to ``metrics.evaluate_rank``.
            ranks: CMC ranks to log and print.
            rerank: apply ``re_ranking`` to the distance matrix.
            iteration: accepted for interface compatibility; unused here.

        Returns:
            ``cmc[0]`` (rank-1 accuracy).
        """
        batch_time = AverageMeter()

        def _extract(loader, desc):
            """Return (features, pids, camids) for every batch of *loader*."""
            feats, all_pids, all_camids = [], [], []
            # Inference only; harmless if the caller already disabled grads.
            with torch.no_grad():
                # Wrap the loader itself (not enumerate(loader)) so tqdm can
                # read its length and show a proper progress bar.
                for data in tqdm(loader, desc):
                    imgs, pids, camids = self._parse_data_for_eval(data)
                    if self.use_gpu:
                        imgs = imgs.cuda()
                    end = time.time()
                    features = self._extract_features(imgs, data[3])
                    batch_time.update(time.time() - end)
                    feats.append(features.data.cpu())
                    all_pids.extend(pids)
                    all_camids.extend(camids)
            return torch.cat(feats, 0), np.asarray(all_pids), np.asarray(all_camids)

        print('Extracting features from query set...')
        qf, q_pids, q_camids = _extract(queryloader, 'Processing query...')
        print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        print('Extracting features from gallery set...')
        gf, g_pids, g_camids = _extract(galleryloader, 'Processing gallery...')
        print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

        print('Speed: {:.4f} sec/batch'.format(batch_time.avg))

        if normalize_feature:
            print('Normalizing features with L2 norm...')
            qf = F.normalize(qf, p=2, dim=1)
            gf = F.normalize(gf, p=2, dim=1)

        print('Computing distance matrix with metric={}...'.format(dist_metric))
        distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
        distmat = distmat.numpy()

        if rerank:
            print('Applying person re-ranking ...')
            distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
            distmat_gg = metrics.compute_distance_matrix(gf, gf, dist_metric)
            distmat = re_ranking(distmat, distmat_qq, distmat_gg)

        print('Computing CMC and mAP ...')
        cmc, mAP = metrics.evaluate_rank(
            distmat,
            q_pids,
            g_pids,
            q_camids,
            g_camids,
            use_metric_cuhk03=use_metric_cuhk03
        )
        if self.writer is not None:
            self.writer.add_scalar('Val/{}/mAP'.format(dataset_name), mAP, epoch + 1)
            for r in ranks:
                self.writer.add_scalar('Val/{}/Rank-{}'.format(dataset_name, r), cmc[r - 1], epoch + 1)

        print('** Results **')
        print('mAP: {:.2%}'.format(mAP))
        print('CMC curve')
        for r in ranks:
            print('Rank-{:<3}: {:.2%}'.format(r, cmc[r-1]))

        if visrank:
            visualize_ranked_results(
                distmat,
                self.datamanager.return_testdataset_by_name(dataset_name),
                self.datamanager.data_type,
                width=self.datamanager.width,
                height=self.datamanager.height,
                save_dir=osp.join(save_dir, 'visrank_' + dataset_name),
                topk=visrank_topk
            )

        return cmc[0]
Beispiel #24
0
    def _evaluate(self,
                  epoch,
                  dataset_name='',
                  queryloader=None,
                  galleryloader=None,
                  dist_metric='euclidean',
                  normalize_feature=False,
                  visrank=False,
                  visrank_topk=20,
                  save_dir='',
                  use_metric_cuhk03=False,
                  ranks=(1, 5, 10, 20),
                  rerank=False,
                  load_pose=False,
                  part_score=False):
        """Evaluate on one dataset, optionally with pose input and part scores.

        Args:
            epoch: current epoch; used to name the ``visrank-<epoch+1>`` folder.
            dataset_name: dataset to visualize and subfolder name.
            queryloader / galleryloader: iterables of batches parsed by
                ``self._parse_data_for_eval``. It returns a 4-tuple including
                ``pose`` when ``load_pose`` is true, otherwise a 3-tuple.
            dist_metric: metric forwarded to the distance functions.
            normalize_feature: L2-normalize features (dim=1) first.
            visrank / visrank_topk / save_dir: optionally save ranked results.
            use_metric_cuhk03: forwarded to ``metrics.evaluate_rank``.
            ranks: CMC ranks to print. A tuple default avoids a shared mutable
                default; any iterable of ints works.
            rerank: apply ``re_ranking``. Note that the qq/gg matrices use the
                unweighted ``metrics.compute_distance_matrix`` even when
                ``part_score`` is set.
            load_pose: pass ``pose`` to ``self._extract_features``.
            part_score: ``self._extract_features`` also returns per-sample
                scores, and ``metrics.compute_weight_distance_matrix`` is
                used. Requires ``load_pose``.

        Returns:
            ``cmc[0]`` (rank-1 accuracy).

        Raises:
            ValueError: if ``part_score`` is set without ``load_pose``. Before,
                that combination failed late on ``torch.cat`` of an empty list.
        """
        if part_score and not load_pose:
            raise ValueError('part_score=True requires load_pose=True')

        batch_time = AverageMeter()

        self.model.eval()

        def _extract(loader):
            """Return (features, pids, camids, scores) for every batch of *loader*.

            ``scores`` is the concatenated score tensor when ``part_score``
            is set, otherwise an empty list.
            """
            feats, all_pids, all_camids, scores = [], [], [], []
            # Inference only; harmless if the caller already disabled grads.
            with torch.no_grad():
                for data in loader:
                    if load_pose:
                        imgs, pids, camids, pose = self._parse_data_for_eval(data)
                    else:
                        imgs, pids, camids = self._parse_data_for_eval(data)
                    if self.use_gpu:
                        imgs = imgs.cuda()
                    end = time.time()
                    if not load_pose:
                        features = self._extract_features(imgs)
                    elif part_score:
                        features, score = self._extract_features(imgs, pose)
                        scores.append(score.data.cpu())
                    else:
                        features = self._extract_features(imgs, pose)
                    batch_time.update(time.time() - end)
                    feats.append(features.data.cpu())
                    all_pids.extend(pids)
                    all_camids.extend(camids)
            if part_score:
                scores = torch.cat(scores)
            return (torch.cat(feats, 0), np.asarray(all_pids),
                    np.asarray(all_camids), scores)

        print('Extracting features from query set ...')
        qf, q_pids, q_camids, q_score = _extract(queryloader)
        print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))

        print('Extracting features from gallery set ...')
        gf, g_pids, g_camids, g_score = _extract(galleryloader)
        print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))

        print('Speed: {:.4f} sec/batch'.format(batch_time.avg))

        if normalize_feature:
            print('Normalzing features with L2 norm ...')
            qf = F.normalize(qf, p=2, dim=1)
            gf = F.normalize(gf, p=2, dim=1)

        print(
            'Computing distance matrix with metric={} ...'.format(dist_metric))
        if part_score:
            distmat = metrics.compute_weight_distance_matrix(
                qf, gf, q_score, g_score, dist_metric)
        else:
            distmat = metrics.compute_distance_matrix(qf, gf, dist_metric)
        distmat = distmat.numpy()

        if rerank:
            print('Applying person re-ranking ...')
            distmat_qq = metrics.compute_distance_matrix(qf, qf, dist_metric)
            distmat_gg = metrics.compute_distance_matrix(gf, gf, dist_metric)
            distmat = re_ranking(distmat, distmat_qq, distmat_gg)

        print('Computing CMC and mAP ...')
        cmc, mAP = metrics.evaluate_rank(distmat,
                                         q_pids,
                                         g_pids,
                                         q_camids,
                                         g_camids,
                                         use_metric_cuhk03=use_metric_cuhk03)

        print('** Results **')
        print('mAP: {:.1%}'.format(mAP))
        print('CMC curve')
        for r in ranks:
            print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))

        if visrank:
            visualize_ranked_results(
                distmat,
                self.datamanager.return_testdataset_by_name(dataset_name),
                save_dir=osp.join(save_dir, 'visrank-' + str(epoch + 1),
                                  dataset_name),
                topk=visrank_topk)

        return cmc[0]