def validate(data_loader, G, F1, F2, args):
    """Evaluate a feature extractor ``G`` with two classifier heads.

    Each batch from ``data_loader`` is moved to ``device`` (a name this
    function reads from the enclosing module scope), passed through ``G``,
    and the resulting features are fed to both ``F1`` and ``F2``. The value
    returned by ``accuracy`` for each head is accumulated in its own meter.

    Args:
        data_loader: Sized iterable yielding ``(images, target)`` batches.
            Must expose ``dataset.classes`` when ``args.per_class_eval`` is
            truthy.
        G: Feature extractor; switched to eval mode.
        F1: First classifier head; switched to eval mode.
        F2: Second classifier head; switched to eval mode.
        args: Namespace providing ``per_class_eval`` and ``print_freq``.

    Returns:
        A tuple ``(acc_head1, acc_head2)`` with the running averages of the
        two accuracy meters.
    """
    timer = AverageMeter('Time', ':6.3f')
    meter_head1 = AverageMeter('Acc_1', ':6.2f')
    meter_head2 = AverageMeter('Acc_2', ':6.2f')
    progress = ProgressMeter(
        len(data_loader),
        [timer, meter_head1, meter_head2],
        prefix='Test: ',
    )

    # Put all three networks into evaluation mode before any forward pass.
    for network in (G, F1, F2):
        network.eval()

    # The confusion matrix is only built on request; it tracks head F1 only.
    confmat = None
    if args.per_class_eval:
        classes = data_loader.dataset.classes
        confmat = ConfusionMatrix(len(classes))

    with torch.no_grad():
        end = time.time()
        for step, (images, target) in enumerate(data_loader):
            images = images.to(device)
            target = target.to(device)

            features = G(images)
            logits1 = F1(features)
            logits2 = F2(features)

            # ``accuracy`` is expected to return a one-element sequence here.
            (acc1,) = accuracy(logits1, target)
            (acc2,) = accuracy(logits2, target)
            if confmat:
                confmat.update(target, logits1.argmax(1))

            batch_size = images.size(0)
            meter_head1.update(acc1.item(), batch_size)
            meter_head2.update(acc2.item(), batch_size)

            # Time the whole iteration, including data loading.
            timer.update(time.time() - end)
            end = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        print(' * Acc1 {top1_1.avg:.3f} Acc2 {top1_2.avg:.3f}'.format(
            top1_1=meter_head1, top1_2=meter_head2))
        if confmat:
            print(confmat.format(classes))

    return meter_head1.avg, meter_head2.avg
def validate(dataloader, target_iter, classifier, device, args):
    """Evaluate ``classifier`` on ``dataloader`` and return the mean top-1 score.

    Each batch is moved to ``device`` and passed through ``classifier``,
    which must return a 2-tuple whose first element is the class logits (the
    second element is ignored). Cross-entropy loss and the top-1 / top-5
    values returned by ``accuracy`` are accumulated in running meters, and
    progress is printed every ``args.print_freq`` batches.

    Args:
        dataloader: Iterable yielding ``(images, target)`` batches. Must
            expose ``dataset.classes`` when ``args.per_class_eval`` is
            truthy.
        target_iter: Kept for backward compatibility. Only used as a
            fallback source for the batch count when ``dataloader`` has no
            length.
        classifier: Model to evaluate; switched to eval mode.
        device: Device the batches are moved to.
        args: Namespace providing ``per_class_eval`` and ``print_freq``.

    Returns:
        The running average of the top-1 meter.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # The loop below iterates ``dataloader``, so the progress total must come
    # from it; sizing the meter from ``target_iter`` shows a wrong total
    # whenever the two lengths differ. Fall back to the legacy source only
    # when ``dataloader`` is unsized.
    try:
        num_batches = len(dataloader)
    except TypeError:
        num_batches = len(target_iter)
    progress = ProgressMeter(
        num_batches,
        [batch_time, losses, top1, top5],
        prefix='Test: ',
    )

    classifier.eval()

    if args.per_class_eval:
        classes = dataloader.dataset.classes
        confmat = ConfusionMatrix(len(classes))
    else:
        confmat = None

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(dataloader):
            images, target = images.to(device), target.to(device)

            output, _ = classifier(images)
            loss = F.cross_entropy(output, target)

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Compare against None explicitly so the check does not depend on
            # the confusion matrix object's truthiness.
            if confmat is not None:
                confmat.update(target, output.argmax(1))

            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # Time the whole iteration, including data loading.
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))
        if confmat is not None:
            print(confmat.format(classes))

    return top1.avg