def benchmark(
    model: nn.Module,
    transform,
    batch_size=64,
    device=device,
    fast: bool = False,
    root: str = "/home/zuppif/Downloads/ImageNet",
    num_workers: int = 12,
):
    """Evaluate *model* on the ImageNet validation split and time the run.

    Args:
        model: network to evaluate; it is switched to eval mode and moved to
            *device* (in place).
        transform: transform passed to ``ImageNet`` for every validation image.
        batch_size: number of images per forward pass.
        device: device the model and each input batch are moved to.
        fast: if True, stop after the first batch (quick smoke test).
        root: dataset root directory passed to ``ImageNet``.
        num_workers: number of ``DataLoader`` worker processes.

    Returns:
        ``(top1, top5, elapsed_seconds)`` taken from the evaluator's results.
        In ``fast`` mode only the first entry is set (``evaluator.top1.avg``
        after the single processed batch) and the other two are ``None``.
    """
    valid_dataset = ImageNet(root=root, split="val", transform=transform)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        # Keep dataset order: batch i must correspond to imgs[i*bs:(i+1)*bs].
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    evaluator = ImageNetEvaluator(model_name="test", paper_arxiv_id="1905.11946")
    model.eval().to(device)

    start = time.perf_counter()
    # The context manager closes the progress bar even if the forward pass raises.
    with torch.no_grad(), tqdm(total=len(valid_loader), leave=False) as pbar:
        for batch_idx, (images, _labels) in enumerate(valid_loader):
            # Labels are not transferred: the evaluator is given only image ids
            # and model outputs.
            images = images.to(device, non_blocking=True)
            net_out = model(images)
            # Recover this batch's file names from the (unshuffled) dataset.
            first = batch_idx * batch_size
            image_ids = [
                get_img_id(item[0])
                for item in valid_dataset.imgs[first:first + batch_size]
            ]
            evaluator.add(dict(zip(image_ids, list(net_out.cpu().numpy()))))
            pbar.set_description(f"top1={evaluator.top1.avg:.2f}")
            pbar.update(1)
            if fast:
                break
    elapsed = time.perf_counter() - start

    if fast:
        return evaluator.top1.avg, None, None
    res = evaluator.get_results()
    return res["Top 1 Accuracy"], res["Top 5 Accuracy"], elapsed
import os

evaluator = ImageNetEvaluator(model_name=model_name, paper_arxiv_id='1905.11946')


def get_img_id(image_name):
    """Return the image id for *image_name*: its file name minus a trailing ``.JPEG``.

    ``os.path.basename`` is used instead of splitting on ``'/'`` so paths built
    with the platform separator (backslashes on Windows) are handled too, and
    only a trailing ``.JPEG`` is stripped rather than every occurrence of it.
    """
    file_name = os.path.basename(image_name)
    if file_name.endswith('.JPEG'):
        file_name = file_name[:-len('.JPEG')]
    return file_name


# NOTE(review): `model` is not put in eval mode here; confirm this happens earlier.
with torch.no_grad():
    # `images` (not `input`) so the builtin `input` is not shadowed at module scope;
    # targets are not moved to the GPU since only image ids and outputs are used.
    for i, (images, _target) in enumerate(test_loader):
        images = images.to(device='cuda', non_blocking=True)
        output = model(images)
        # Map this batch back to file names; assumes test_loader does not shuffle.
        image_ids = [
            get_img_id(img[0])
            for img in test_loader.dataset.imgs[
                i * test_loader.batch_size:(i + 1) * test_loader.batch_size
            ]
        ]
        evaluator.add(dict(zip(image_ids, list(output.cpu().numpy()))))
        # Presumably true when results for this model are already cached; stop early.
        if evaluator.cache_exists:
            break

if not is_server():
    print("Results:")
    print(evaluator.get_results())

evaluator.save()