def _mean_class_accuracy(total_right_class, total_seen_class):
    """Return the mean of per-class accuracies over classes seen so far (0.0 if none)."""
    seen = total_seen_class > 0
    if not seen.any():
        return 0.0
    # Averaging only over seen classes keeps mid-epoch logs finite instead of
    # producing NaN (and a divide-by-zero warning) before every class appears.
    return float(np.mean(total_right_class[seen] / total_seen_class[seen]))


def validate(val_loader, net, epoch, print_pr=False, num_classes=40):
    """Run one validation epoch on the val set and report metrics.

    Args:
        val_loader: iterable yielding ``(views, pcs, labels)`` batches.
        net: model called as ``net(pcs, views, get_fea=True)`` and expected to
            return ``(preds, fts)``.
        epoch: epoch index, used only in the log lines.
        print_pr: if True, also print ``retrieval_map.pr()``.
        num_classes: number of class labels used for the per-class counters;
            labels are assumed to be integers in ``[0, num_classes)``.
            Defaults to 40, the previously hard-coded value.

    Returns:
        Tuple ``(prec.value(1), mAP)``: the top-1 accuracy reported by
        ``meter.ClassErrorMeter`` and ``retrieval_map.mAP()``.
    """
    batch_time = meter.TimeMeter(True)
    data_time = meter.TimeMeter(True)
    prec = meter.ClassErrorMeter(topk=[1], accuracy=True)
    retrieval_map = meter.RetrievalMAPMeter()

    # Integer per-class counters used for the mean class accuracy.
    total_seen_class = np.zeros(num_classes, dtype=np.int64)
    total_right_class = np.zeros(num_classes, dtype=np.int64)

    # testing mode
    net.eval()
    # No gradients are needed for evaluation; avoid building the autograd graph.
    with torch.no_grad():
        for i, (views, pcs, labels) in enumerate(val_loader):
            batch_time.reset()
            views = views.to(device=config.device)
            pcs = pcs.to(device=config.device)
            labels = labels.to(device=config.device)

            preds, fts = net(pcs, views, get_fea=True)

            prec.add(preds.detach(), labels.detach())
            fts = fts.detach()
            # L2-normalise features along dim 1 before adding them to the retrieval meter.
            retrieval_map.add(fts / torch.norm(fts, 2, 1, True), labels.detach())

            # Predicted class per sample, computed once per batch on-device and
            # then moved to CPU together with the labels.
            pred_labels = preds.argmax(dim=1).cpu().numpy()
            true_labels = labels.cpu().numpy()
            total_seen_class += np.bincount(true_labels, minlength=num_classes)
            total_right_class += np.bincount(
                true_labels[pred_labels == true_labels], minlength=num_classes
            )

            if i % config.print_freq == 0:
                print(
                    f'Epoch: [{epoch}][{i}/{len(val_loader)}]\t'
                    f'Batch Time {batch_time.value():.3f}\t'
                    f'Epoch Time {data_time.value():.3f}\t'
                    f'Prec@1 {prec.value(1):.3f}\t'
                    f'Mean Class accuracy {_mean_class_accuracy(total_right_class, total_seen_class):.3f}'
                )

    mAP = retrieval_map.mAP()
    print(f' instance accuracy at epoch {epoch}: {prec.value(1)} ')
    print(
        f' mean class accuracy at epoch {epoch}: {_mean_class_accuracy(total_right_class, total_seen_class)} '
    )
    print(f' map at epoch {epoch}: {mAP} ')
    if print_pr:
        print(f'pr: {retrieval_map.pr()}')
    return prec.value(1), mAP
def validate(val_loader, net):
    """Run one validation epoch on the val set and report metrics.

    Args:
        val_loader: iterable yielding ``(views, labels)`` batches.
        net: model called as ``net(views, get_ft=True)`` and expected to
            return ``(preds, fts)``.

    Returns:
        Tuple ``(prec.value(1), mAP)``: the top-1 accuracy reported by
        ``meter.ClassErrorMeter`` and ``retrieval.mAP()``.
    """
    batch_time = meter.TimeMeter(True)
    data_time = meter.TimeMeter(True)
    prec = meter.ClassErrorMeter(topk=[1], accuracy=True)
    retrieval = meter.RetrievalMAPMeter()

    # testing mode
    net.eval()
    # No gradients are needed for evaluation; avoid building the autograd graph.
    with torch.no_grad():
        for i, (views, labels) in enumerate(val_loader):
            batch_time.reset()
            # presumably bz x 12 x 3 x 224 x 224 (per the original comment) — verify
            views = views.to(device=config.device)
            labels = labels.to(device=config.device)

            preds, fts = net(views, get_ft=True)

            prec.add(preds.detach(), labels.detach())
            retrieval.add(fts.detach(), labels.detach())

            if i % config.print_freq == 0:
                print(f'[{i}/{len(val_loader)}]\t'
                      f'Batch Time {batch_time.value():.3f}\t'
                      f'Epoch Time {data_time.value():.3f}\t'
                      f'Prec@1 {prec.value(1):.3f}\t')

    mAP = retrieval.mAP()
    # prec.value(1) is overall top-1 (instance) accuracy, not a per-class mean.
    print(f'instance accuracy : {prec.value(1)} ')
    print(f'Retrieval mAP : {mAP} ')
    return prec.value(1), mAP