Пример #1
0
def validate(val_loader, net, epoch, print_pr=False, num_classes=40):
    """Run one validation pass over ``val_loader`` and report metrics.

    Args:
        val_loader: Iterable yielding ``(views, pcs, labels)`` batches of
            tensors; each is moved to ``config.device`` before use.
        net: Model invoked as ``net(pcs, views, get_fea=True)`` and expected
            to return ``(preds, fts)``, where ``preds`` holds per-class scores
            along dimension 1 and ``fts`` holds the retrieval features.
        epoch: Epoch number; only used in the log messages.
        print_pr: When true, also print ``retrieval_map.pr()``.
        num_classes: Number of classes tracked for the per-class accuracy.
            Labels must lie in ``[0, num_classes)``.

    Returns:
        Tuple ``(prec.value(1), mAP)``: the top-1 value reported by the
        ``ClassErrorMeter`` and the retrieval mAP.
    """
    batch_time = meter.TimeMeter(True)
    data_time = meter.TimeMeter(True)
    prec = meter.ClassErrorMeter(topk=[1], accuracy=True)
    retrieval_map = meter.RetrievalMAPMeter()

    # testing mode
    net.eval()

    total_seen_class = np.zeros(num_classes, dtype=np.int64)
    total_right_class = np.zeros(num_classes, dtype=np.int64)

    def mean_class_accuracy():
        # Average only over classes that have been observed so far; dividing
        # by a zero count would turn the whole mean into NaN.
        seen = total_seen_class > 0
        if not seen.any():
            return 0.0
        return float(np.mean(total_right_class[seen] / total_seen_class[seen]))

    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for i, (views, pcs, labels) in enumerate(val_loader):
            batch_time.reset()

            views = views.to(device=config.device)
            pcs = pcs.to(device=config.device)
            labels = labels.to(device=config.device)

            preds, fts = net(pcs, views, get_fea=True)

            prec.add(preds.detach(), labels.detach())
            # Features are L2-normalised along dim 1 before being handed to
            # the retrieval meter.
            retrieval_map.add(
                fts.detach() / torch.norm(fts.detach(), 2, 1, True),
                labels.detach())

            # Compute the predicted class once per batch, on the CPU, so the
            # per-class counters are plain integers.
            label_idx = labels.detach().cpu().numpy().reshape(-1)
            label_idx = label_idx.astype(np.int64)
            pred_idx = preds.detach().argmax(dim=1).cpu().numpy()
            # np.add.at accumulates correctly when a class repeats in a batch.
            np.add.at(total_seen_class, label_idx, 1)
            np.add.at(total_right_class, label_idx[pred_idx == label_idx], 1)

            if i % config.print_freq == 0:
                print(
                    f'Epoch: [{epoch}][{i}/{len(val_loader)}]\t'
                    f'Batch Time {batch_time.value():.3f}\t'
                    f'Epoch Time {data_time.value():.3f}\t'
                    f'Prec@1 {prec.value(1):.3f}\t'
                    f'Mean Class accuracy {mean_class_accuracy():.3f}'
                )

    mAP = retrieval_map.mAP()
    print(f' instance accuracy at epoch {epoch}: {prec.value(1)} ')
    print(
        f' mean class accuracy at epoch {epoch}: {mean_class_accuracy()} '
    )
    print(f' map at epoch {epoch}: {mAP} ')
    if print_pr:
        print(f'pr: {retrieval_map.pr()}')
    return prec.value(1), mAP
Пример #2
0
def validate(val_loader, net):
    """Run one validation pass over ``val_loader`` and report metrics.

    Args:
        val_loader: Iterable yielding ``(views, labels)`` batches of tensors;
            both are moved to ``config.device`` before use.
        net: Model invoked as ``net(views, get_ft=True)`` and expected to
            return ``(preds, fts)``: predictions fed to the accuracy meter and
            features fed to the retrieval meter.

    Returns:
        Tuple ``(prec.value(1), mAP)``: the top-1 value reported by the
        ``ClassErrorMeter`` and the retrieval mAP.
    """
    # Imported locally so the block does not rely on a module-level import.
    import torch

    batch_time = meter.TimeMeter(True)
    data_time = meter.TimeMeter(True)
    prec = meter.ClassErrorMeter(topk=[1], accuracy=True)
    retrieval = meter.RetrievalMAPMeter()

    # testing mode
    net.eval()

    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for i, (views, labels) in enumerate(val_loader):
            batch_time.reset()
            # assumes views is bz x 12 x 3 x 224 x 224 -- TODO confirm
            views = views.to(device=config.device)
            labels = labels.to(device=config.device)

            preds, fts = net(views, get_ft=True)

            prec.add(preds.detach(), labels.detach())
            retrieval.add(fts.detach(), labels.detach())

            if i % config.print_freq == 0:
                print(f'[{i}/{len(val_loader)}]\t'
                      f'Batch Time {batch_time.value():.3f}\t'
                      f'Epoch Time {data_time.value():.3f}\t'
                      f'Prec@1 {prec.value(1):.3f}\t')

    mAP = retrieval.mAP()
    # NOTE(review): the label below says "mean class accuracy", but the value
    # is the meter's top-1 figure over all samples; no per-class averaging is
    # done in this function. Message kept as-is for log compatibility.
    print(f'mean class accuracy : {prec.value(1)} ')
    print(f'Retrieval mAP : {mAP} ')
    return prec.value(1), mAP