Code example #1
File: benchmark.py  Project: ajayarunachalam/glasses
import math
import time

import numpy as np
import torch
from torch import nn
from torchvision.datasets import ImageNet
from tqdm import tqdm

# Assumed context: ImageNetEvaluator, get_img_id and the module-level `device`
# are defined elsewhere in the project and are not shown in this excerpt.


def benchmark(model: nn.Module,
              transform,
              batch_size=64,
              device=device,
              fast: bool = False):

    valid_dataset = ImageNet(root="/home/zuppif/Downloads/ImageNet",
                             split="val",
                             transform=transform)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=12,
        pin_memory=True,
    )

    # Evaluator that accumulates top-1 / top-5 accuracy; arXiv:1905.11946 is
    # the EfficientNet paper used as the reference entry.
    evaluator = ImageNetEvaluator(model_name="test",
                                  paper_arxiv_id="1905.11946")
    model.eval().to(device)

    # Number of validation batches, used only to size the progress bar.
    num_batches = int(
        math.ceil(len(valid_loader.dataset) / float(valid_loader.batch_size)))

    start = time.time()

    with torch.no_grad():
        pbar = tqdm(np.arange(num_batches), leave=False)
        for i_val, (images, labels) in enumerate(valid_loader):

            images = images.to(device)
            labels = torch.squeeze(labels.to(device))

            net_out = model(images)

            # Recover the ImageNet file ids for the images in the current
            # batch so predictions can be matched to ground-truth labels.
            image_ids = [
                get_img_id(img[0]) for img in
                valid_loader.dataset.imgs[i_val *
                                          valid_loader.batch_size:(i_val + 1) *
                                          valid_loader.batch_size]
            ]
            evaluator.add(dict(zip(image_ids, list(net_out.cpu().numpy()))))
            pbar.set_description(f"top1={evaluator.top1.avg:.2f}")
            pbar.update(1)
            if fast:
                break
        pbar.close()
    stop = time.time()
    if fast:
        return evaluator.top1.avg, None, None
    else:
        res = evaluator.get_results()
        return res["Top 1 Accuracy"], res["Top 5 Accuracy"], stop - start
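
For reference, a minimal sketch of how this benchmark function could be called; the torchvision model and preprocessing used below are illustrative assumptions, not part of the original project:

from torchvision import transforms as T
from torchvision.models import resnet18

# Standard ImageNet preprocessing (illustrative; each model may need its own
# input size and normalization values).
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model = resnet18(pretrained=True)
top1, top5, elapsed = benchmark(model, transform, batch_size=64, device="cuda")
print(f"top1={top1:.3f} top5={top5:.3f} time={elapsed:.1f}s")
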
Code example #2
import torch
from sotabencheval.image_classification import ImageNetEvaluator
from sotabencheval.utils import is_server

# Assumed context: `model`, `model_name` and `test_loader` are defined earlier
# in the original script; only the evaluation loop is shown here.

evaluator = ImageNetEvaluator(model_name=model_name,
                              paper_arxiv_id='1905.11946')


def get_img_id(image_name):
    # Strip the directory and the '.JPEG' extension, leaving the bare
    # ImageNet image id used by the evaluator.
    return image_name.split('/')[-1].replace('.JPEG', '')


with torch.no_grad():
    for i, (input, target) in enumerate(test_loader):
        input = input.to(device='cuda', non_blocking=True)
        target = target.to(device='cuda', non_blocking=True)
        output = model(input)
        # Map the files of the current batch back to their ImageNet image ids.
        image_ids = [
            get_img_id(img[0])
            for img in test_loader.dataset.imgs[i *
                                                test_loader.batch_size:(i +
                                                                        1) *
                                                test_loader.batch_size]
        ]
        evaluator.add(dict(zip(image_ids, list(output.cpu().numpy()))))
        # The evaluator caches results; once a matching cached run is found
        # there is no need to finish the full validation pass.
        if evaluator.cache_exists:
            break

# Print the results when running locally; evaluator.save() persists them for
# the benchmark run.
if not is_server():
    print("Results:")
    print(evaluator.get_results())

evaluator.save()
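
As a quick illustration of what the loop hands to evaluator.add, here is a tiny self-contained check of the id-to-prediction mapping; the file path and the random logits are made-up stand-ins, not data from the original run:

import numpy as np

def get_img_id(image_name):
    return image_name.split('/')[-1].replace('.JPEG', '')

# Hypothetical (path, class) entry in the form stored by an ImageFolder-style
# dataset in `dataset.imgs`.
batch_imgs = [("/data/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG", 0)]
logits = np.random.randn(1, 1000)  # stand-in for one batch of model outputs

image_ids = [get_img_id(path) for path, _ in batch_imgs]
predictions = dict(zip(image_ids, list(logits)))
print(list(predictions.keys()))  # ['ILSVRC2012_val_00000293']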