Code Example #1
def test_cross_dataset(config_file, test_dataset, **kwargs):
    cfg.merge_from_file(config_file)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()
    
    PersonReID_Dataset_Downloader('./datasets', cfg.DATASETS.NAMES)
    _, _, _, num_classes = data_loader(cfg, cfg.DATASETS.NAMES)

    PersonReID_Dataset_Downloader('./datasets', test_dataset)
    _, val_loader, num_query, _ = data_loader(cfg, test_dataset)

    re_ranking = cfg.RE_RANKING
    
    if not re_ranking:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR,
                             cfg.DATASETS.NAMES+'->'+test_dataset)
        logger.info("Test Results:")
    else:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR,
                             cfg.DATASETS.NAMES+'->'+test_dataset+'_re-ranking')
        logger.info("Re-Ranking Test Results:") 
        
    device = torch.device(cfg.DEVICE)
    
    model = getattr(models, cfg.MODEL.NAME)(num_classes)
    model.load(cfg.OUTPUT_DIR, cfg.TEST.LOAD_EPOCH)
    model = model.eval()
    
    all_feats = []
    all_pids = []
    all_camids = []
    
    since = time.time()
    for data in tqdm(val_loader, desc='Feature Extraction', leave=False):
        model.eval()
        with torch.no_grad():
            images, pids, camids = data
            if device:
                model.to(device)
                images = images.to(device)
            
            feats = model(images)

        all_feats.append(feats)
        all_pids.extend(np.asarray(pids))
        all_camids.extend(np.asarray(camids))

    cmc, mAP = evaluation(all_feats, all_pids, all_camids, num_query, re_ranking)

    logger.info("mAP: {:.1%}".format(mAP))
    for r in [1, 5, 10]:
        logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
       
    test_time = time.time() - since
    logger.info('Testing complete in {:.0f}m {:.0f}s'.format(test_time // 60, test_time % 60))
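
For reference, the kwargs in `test_cross_dataset` are flattened into a `[key, value, key, value, ...]` list and handed to `cfg.merge_from_list`, so overrides use yacs-style dotted keys. A hypothetical invocation might look like the sketch below; the config path, target dataset name, and override values are placeholders, not taken from the examples.

# Hypothetical driver; the config path, dataset name and override values are assumptions.
if __name__ == '__main__':
    test_cross_dataset(
        'configs/baseline.yaml',        # assumed config file path
        'DukeMTMC',                     # assumed target dataset name
        **{'MODEL.NAME': 'resnet50',    # dotted yacs keys forwarded to cfg.merge_from_list
           'TEST.LOAD_EPOCH': 120})
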
Code Example #2
def train(config_file, **kwargs):
    cfg.merge_from_file(config_file)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()

    PersonReID_Dataset_Downloader(cfg.DATASETS.STORE_DIR, cfg.DATASETS.NAMES)

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = make_logger("Reid_Baseline", output_dir, 'log')
    logger.info("Using {} GPUS".format(1))
    logger.info("Loaded configuration file {}".format(config_file))
    logger.info("Running with config:\n{}".format(cfg))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = torch.device(cfg.DEVICE)
    epochs = cfg.SOLVER.MAX_EPOCHS

    train_loader, val_loader, num_query, num_classes = data_loader(
        cfg, cfg.DATASETS.NAMES)

    model = getattr(models, cfg.MODEL.NAME)(num_classes, cfg.MODEL.LAST_STRIDE,
                                            cfg.MODEL.POOL)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_scheduler(cfg, optimizer)
    loss_fn = make_loss(cfg)

    logger.info("Start training")
    since = time.time()
    for epoch in range(epochs):
        count = 0
        running_loss = 0.0
        running_acc = 0
        for data in tqdm(train_loader, desc='Iteration', leave=False):
            model.train()
            images, labels = data
            if device:
                model.to(device)
                images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            scores, feats = model(images)
            loss = loss_fn(scores, feats, labels)

            loss.backward()
            optimizer.step()

            count = count + 1
            running_loss += loss.item()
            running_acc += (
                scores[0].max(1)[1] == labels).float().mean().item()

        logger.info(
            "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
            .format(epoch + 1, count, len(train_loader), running_loss / count,
                    running_acc / count,
                    scheduler.get_lr()[0]))
        scheduler.step()

        if (epoch + 1) % checkpoint_period == 0:
            model.cpu()
            model.save(output_dir, epoch + 1)

        # Validation
        if (epoch + 1) % eval_period == 0:
            all_feats = []
            all_pids = []
            all_camids = []
            for data in tqdm(val_loader,
                             desc='Feature Extraction',
                             leave=False):
                model.eval()
                with torch.no_grad():
                    images, pids, camids = data
                    if device:
                        model.to(device)
                        images = images.to(device)

                    feats = model(images)

                all_feats.append(feats)
                all_pids.extend(np.asarray(pids))
                all_camids.extend(np.asarray(camids))

            cmc, mAP = evaluation(all_feats, all_pids, all_camids, num_query)
            logger.info("Validation Results - Epoch: {}".format(epoch + 1))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))

    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logger.info('-' * 10)
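
The `evaluation` helper called in examples #1 and #2 is not shown here. As a rough reference only, a minimal sketch of the conventional Market-1501-style CMC/mAP computation over a query/gallery split is given below; the function name `eval_cmc_map` and its exact conventions are assumptions, not the project's actual implementation.

import numpy as np
import torch

def eval_cmc_map(all_feats, all_pids, all_camids, num_query, max_rank=10):
    """Minimal CMC/mAP sketch: the first num_query entries are queries, the rest gallery."""
    feats = torch.cat(all_feats, dim=0)
    qf, gf = feats[:num_query], feats[num_query:]
    q_pids, g_pids = np.asarray(all_pids[:num_query]), np.asarray(all_pids[num_query:])
    q_camids, g_camids = np.asarray(all_camids[:num_query]), np.asarray(all_camids[num_query:])

    # squared Euclidean distance matrix between every query and gallery feature
    distmat = torch.cdist(qf, gf).pow(2).cpu().numpy()
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    all_cmc, all_ap = [], []
    for i in range(num_query):
        order = indices[i]
        # discard gallery images of the same identity taken by the same camera as the query
        keep = ~((g_pids[order] == q_pids[i]) & (g_camids[order] == q_camids[i]))
        raw_cmc = matches[i][keep]
        if not np.any(raw_cmc):
            continue  # this query has no valid gallery match
        cmc = raw_cmc.cumsum()
        cmc[cmc > 1] = 1
        all_cmc.append(cmc[:max_rank])
        precision = raw_cmc.cumsum() / (np.arange(len(raw_cmc)) + 1.0)
        all_ap.append((precision * raw_cmc).sum() / raw_cmc.sum())

    cmc = np.stack(all_cmc).mean(axis=0)  # assumes every kept ranking has >= max_rank entries
    return cmc, float(np.mean(all_ap))

The returned values would be consumed exactly as the examples do: log `mAP` and `cmc[r - 1]` for r in {1, 5, 10}.
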
Code Example #3
    def __init__(self, config_file, epoch_label, **kwargs):
        """
        Validation set is split into two parts - query (probe) and gallery (to be searched), based on num_query.

        ::Return: Initialize a file 'model_epoch.mtch':
                matching matrix M of num_query x num_gallery. M_ij is 1 <=> ith query is matched at rank j.
        """

        cfg.merge_from_file(config_file)

        if kwargs:
            opts = []
            for k, v in kwargs.items():
                opts.append(k)
                opts.append(v)
            cfg.merge_from_list(opts)
        cfg.freeze()
        self.cfg = cfg

        device = torch.device(cfg.DEVICE)
        output_dir = cfg.OUTPUT_DIR
        epoch = epoch_label
        re_ranking = cfg.RE_RANKING
        if not os.path.exists(output_dir):
            raise OSError('Output directory does not exist.')
        save_filename = (cfg.MODEL.NAME + '_epo%s.mtch' % epoch_label)
        self._filepath = os.path.join(output_dir, save_filename)

        if os.path.exists(self._filepath):
            print('Loading matches file...')
            self.data = np.load(self._filepath)
            train_loader, val_loader, num_query, num_classes = data_loader(
                cfg, cfg.DATASETS.NAMES)
            self.dataset = val_loader.dataset
            print('Matches loaded.')
        else:
            print('Creating matches file...')
            PersonReID_Dataset_Downloader(cfg.DATASETS.STORE_DIR,
                                          cfg.DATASETS.NAMES)

            train_loader, val_loader, num_query, num_classes = data_loader(
                cfg, cfg.DATASETS.NAMES)

            # load model
            model = getattr(models, cfg.MODEL.NAME)(num_classes)
            model.load(output_dir, epoch)
            model.eval()

            all_feats = []
            all_pids = []
            all_camids = []
            for data in tqdm(val_loader,
                             desc='Feature Extraction',
                             leave=False):
                with torch.no_grad():

                    images, pids, camids = data

                    if device:
                        model.to(device)
                        images = images.to(device)

                    feats = model(images)

                all_feats.append(feats)
                all_pids.extend(np.asarray(pids))
                all_camids.extend(np.asarray(camids))

            all_feats = torch.cat(all_feats, dim=0)
            # query
            qf = all_feats[:num_query]
            q_pids = np.asarray(all_pids[:num_query])
            q_camids = np.asarray(all_camids[:num_query])

            # gallery
            gf = all_feats[num_query:]
            g_pids = np.asarray(all_pids[num_query:])
            g_camids = np.asarray(all_camids[num_query:])

            if re_ranking:
                raise NotImplementedError()
            else:
                # squared Euclidean distances: ||q||^2 + ||g||^2 - 2 * q . g
                m, n = qf.shape[0], gf.shape[0]
                distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
                          torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
                distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
                distmat = distmat.cpu().numpy()

            indices = np.argsort(distmat, axis=1)
            # matches = np.repeat(g_pids.reshape([1, n]), m, axis=0) == q_pids[:, np.newaxis]
            ranked_matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(
                np.int32)

            data = {
                'q_pids': q_pids,
                'g_pids': g_pids,
                'q_camids': q_camids,
                'g_camids': g_camids,
                'ranked_matches': ranked_matches,
                # 'matches': matches,
                'indices': indices,
            }

            # save as .mtch
            with open(self._filepath, 'wb') as f:
                np.savez(f, **data)

            print('Matches created.')

            self.data = data
            self.dataset = val_loader.dataset
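
A toy illustration of the `ranked_matches` matrix described in the docstring above; the identities and distances below are invented purely for this example.

import numpy as np

# 2 queries, 4 gallery images; values chosen only to illustrate the ranking
q_pids = np.array([7, 3])
g_pids = np.array([3, 7, 7, 5])
distmat = np.array([[0.9, 0.1, 0.4, 0.8],
                    [0.2, 0.7, 0.6, 0.9]])

indices = np.argsort(distmat, axis=1)          # gallery ranked by distance per query
ranked_matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

print(indices)         # [[1 2 3 0]
                       #  [0 2 1 3]]
print(ranked_matches)  # [[1 1 0 0]   -> for query 0, ranks 1 and 2 are correct matches
                       #  [1 0 0 0]]  -> for query 1, only rank 1 is a correct match
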
Code Example #4
def test(config_file, **kwargs):
    cfg.merge_from_file(config_file)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()
    
    use_re_ranking = cfg.RE_RANKING  # renamed so it does not shadow the re_ranking() function called below
    
    PersonReID_Dataset_Downloader('./datasets', cfg.DATASETS.NAMES)
    if not use_re_ranking:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR, 'result')
        logger.info("Test Results:")
    else:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR, 'result_re-ranking')
        logger.info("Re-Ranking Test Results:")
    
    device = torch.device(cfg.DEVICE)
    
    _, val_loader, num_query, num_classes = data_loader(cfg, cfg.DATASETS.NAMES)
    
    model = getattr(models, cfg.MODEL.NAME)(num_classes)
    model.load(cfg.OUTPUT_DIR, cfg.TEST.LOAD_EPOCH)
    if device:
        model.to(device) 
    model = model.eval()

    all_feats = []
    all_pids = []
    all_camids = []
    all_imgs = []
    
    for data in tqdm(val_loader, desc='Feature Extraction', leave=False):
        with torch.no_grad():
            images, pids, camids = data
            all_imgs.extend(images.numpy())
            if device:
                images = images.to(device)  # model was already moved to the device above
            
            feats = model(images)

        all_feats.append(feats)
        all_pids.extend(np.asarray(pids))
        all_camids.extend(np.asarray(camids))

    all_feats = torch.cat(all_feats, dim=0)
    # query
    qf = all_feats[:num_query]
    q_pids = np.asarray(all_pids[:num_query])
    q_camids = np.asarray(all_camids[:num_query])
    q_imgs = all_imgs[:num_query]
    # gallery
    gf = all_feats[num_query:]
    g_pids = np.asarray(all_pids[num_query:])
    g_camids = np.asarray(all_camids[num_query:])
    g_imgs = all_imgs[num_query:]

    if not use_re_ranking:
        # squared Euclidean distances: ||q||^2 + ||g||^2 - 2 * q . g
        m, n = qf.shape[0], gf.shape[0]
        distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
                  torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
        distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
        distmat = distmat.cpu().numpy()
    else:
        print('Calculating Distance')
        q_g_dist = np.dot(qf.data.cpu(), np.transpose(gf.data.cpu()))
        q_q_dist = np.dot(qf.data.cpu(), np.transpose(qf.data.cpu()))
        g_g_dist = np.dot(gf.data.cpu(), np.transpose(gf.data.cpu()))
        print('Re-ranking:')
        distmat = re_ranking(q_g_dist, q_q_dist, g_g_dist)

    indices = np.argsort(distmat, axis=1)

    mean = cfg.INPUT.PIXEL_MEAN
    std = cfg.INPUT.PIXEL_STD
    top_k = 7
    os.makedirs('./show', exist_ok=True)  # ensure the output directory for result images exists (requires the standard os import)
    for i in range(num_query):
        # get query pid and camid
        q_pid = q_pids[i]
        q_camid = q_camids[i]

        # remove gallery samples that have the same pid and camid with query
        order = indices[i]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)
        # ranked gallery indices after removing the junk images above
        ranked_index = indices[i][keep]

        # plot the query image followed by its top_k ranked gallery images
        fig = plt.figure(figsize=(16, 4))
        plt.subplot(1, top_k + 1, 1)
        plt.title("query", fontsize=15)
        img = np.clip(q_imgs[i].transpose(1, 2, 0) * std + mean, 0.0, 1.0)
        plt.imshow(img)
        plt.axis('off')
        for j in range(top_k):
            plt.subplot(1, top_k + 1, 2 + j)
            img = np.clip(g_imgs[ranked_index[j]].transpose(1, 2, 0) * std + mean, 0.0, 1.0)
            plt.imshow(img)
            plt.axis('off')
        plt.savefig("./show/{}.jpg".format(i))
        plt.close(fig)
            
    logger.info('Testing complete')
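
Finally, as a quick sanity check on the expanded squared-distance formula used in examples #3 and #4 (||q||^2 + ||g||^2 - 2 q.g via `addmm_`), the same matrix can be obtained with `torch.cdist`; this snippet is only an illustrative check, not part of the original code.

import torch

torch.manual_seed(0)
qf = torch.randn(5, 8)    # toy query features
gf = torch.randn(7, 8)    # toy gallery features

m, n = qf.shape[0], gf.shape[0]
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
          torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)

# identical (up to floating point error) to the squared pairwise Euclidean distances
assert torch.allclose(distmat, torch.cdist(qf, gf).pow(2), atol=1e-5)
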