コード例 #1
0
def test_accuracy():
    """Check that quantization shifts validation accuracy by less than 0.01.

    A small fully connected classifier is trained for one epoch, its accuracy
    on the "valid" loader is measured, the runner's model is replaced with the
    quantized version, and accuracy is measured again on the same loader.
    """
    layers = [
        torch.nn.Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.Linear(64, 10),
    ]
    model = torch.nn.Sequential(*layers)

    # Both splits are built with train=False, exactly as the test specifies.
    # NOTE(review): presumably done to keep the test fast -- confirm.
    split_datasets = {
        split: MNIST(DATA_ROOT, train=False) for split in ("train", "valid")
    }
    loaders = {}
    for split, dataset in split_datasets.items():
        loaders[split] = torch.utils.data.DataLoader(dataset, batch_size=32)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    runner = SupervisedRunner()
    accuracy_callback = AccuracyCallback(target_key="targets", input_key="logits")
    criterion = torch.nn.CrossEntropyLoss()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        callbacks=[accuracy_callback],
        num_epochs=1,
        valid_loader="valid",
        valid_metric="accuracy01",
        minimize_valid_metric=False,
    )

    valid_loader = loaders["valid"]
    accuracy_before = _evaluate_loader_accuracy(runner, valid_loader)

    # Swap the quantized model into the same runner so that both measurements
    # go through an identical evaluation path.
    runner.model = quantize_model(model)
    accuracy_after = _evaluate_loader_accuracy(runner, valid_loader)

    accuracy_gap = abs(accuracy_before - accuracy_after)
    assert accuracy_gap < 0.01
コード例 #2
0
def test_api():
    """Test that a model can be quantized through the API.

    The state dicts of the original and the quantized model are both
    serialized to disk, and the quantized checkpoint must be more than
    3.8 times smaller than the original one.
    """
    # Function-scope import so the block does not depend on the file's
    # top-level import list.
    import tempfile

    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.Linear(64, 10),
    )
    q_model = quantize_model(model)

    # Write the checkpoints into a private temporary directory: nothing is
    # left behind in the working directory, concurrent test runs cannot clash
    # on file names, and the files are removed even if saving raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.pth")
        q_model_path = os.path.join(tmp_dir, "q_model.pth")
        torch.save(model.state_dict(), model_path)
        torch.save(q_model.state_dict(), q_model_path)
        model_size = os.path.getsize(model_path)
        q_model_size = os.path.getsize(q_model_path)

    # NOTE(review): the 3.8 factor presumably reflects the ~4x shrink expected
    # from float32 -> int8 weights, minus serialization overhead -- confirm.
    assert q_model_size * 3.8 < model_size