Exemple #1
0
def test_accuracy():
    """Check that quantization changes validation accuracy by less than 0.01.

    A small MLP is trained on MNIST for one epoch, its accuracy on the
    validation loader is measured with ``_evaluate_loader_accuracy``, the
    model is passed through ``quantize_model`` and the measurement is
    repeated on the quantized model.
    """
    layers = [
        Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.Linear(64, 10),
    ]
    model = torch.nn.Sequential(*layers)

    train_set = MNIST("./data", transform=ToTensor(), download=True)
    valid_set = MNIST("./data", transform=ToTensor(), train=False)
    dataloaders = {}
    for split, dataset in (("train", train_set), ("valid", valid_set)):
        dataloaders[split] = torch.utils.data.DataLoader(dataset, batch_size=32)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    accuracy_callback = AccuracyCallback(target_key="targets", input_key="logits")

    runner = SupervisedRunner()
    runner.train(
        model=model,
        criterion=torch.nn.CrossEntropyLoss(),
        optimizer=optimizer,
        loaders=dataloaders,
        callbacks=[accuracy_callback],
        num_epochs=1,
        valid_loader="valid",
        valid_metric="accuracy01",
        minimize_valid_metric=False,
    )

    valid_loader = dataloaders["valid"]
    accuracy_before = _evaluate_loader_accuracy(runner, valid_loader)

    # Swap the quantized model into the same runner and re-evaluate.
    runner.model = quantize_model(model)
    accuracy_after = _evaluate_loader_accuracy(runner, valid_loader)

    accuracy_gap = abs(accuracy_before - accuracy_after)
    assert accuracy_gap < 0.01
def test_mnist():
    """Smoke-test an Optuna learning-rate search driven by ``OptunaCallback``.

    Runs a study of up to 5 trials (300 s timeout) that samples the Adam
    learning rate for a small MNIST classifier and maximizes the best
    validation ``accuracy01``, with a ``MedianPruner`` attached.
    """
    # NOTE(review): both splits load ``train=False`` -- presumably to keep the
    # test small; confirm this is intentional.
    trainset = MNIST(
        "./data",
        train=False,
        download=True,
        transform=ToTensor(),
    )
    testset = MNIST(
        "./data",
        train=False,
        download=True,
        transform=ToTensor(),
    )
    loaders = {
        "train": DataLoader(trainset, batch_size=32),
        "valid": DataLoader(testset, batch_size=64),
    }

    def objective(trial):
        """Train a fresh model with a sampled lr; return its best valid metric."""
        # Build the model per trial. A model shared across trials would let
        # each trial continue from weights trained by the previous ones, so
        # trials (and the pruner's comparison between them) would not be
        # independent.
        model = nn.Sequential(
            Flatten(),
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
        lr = trial.suggest_loguniform("lr", 1e-3, 1e-1)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            loaders=loaders,
            criterion=criterion,
            optimizer=optimizer,
            callbacks=[
                OptunaCallback(trial),
                AccuracyCallback(num_classes=10),
            ],
            num_epochs=10,
            main_metric="accuracy01",
            minimize_metric=False,
        )
        return runner.best_valid_metrics[runner.main_metric]

    study = optuna.create_study(
        direction="maximize",
        pruner=optuna.pruners.MedianPruner(
            n_startup_trials=1,
            n_warmup_steps=0,
            interval_steps=1,
        ),
    )
    study.optimize(objective, n_trials=5, timeout=300)
    # Replaces a vacuous ``assert True``: the study must have recorded trials.
    assert len(study.trials) > 0
Exemple #3
0
def test_api():
    """Test that a model can be quantized through the API.

    Saves the state dicts of the original and the quantized model and
    asserts that the quantized file is more than 3.8 times smaller.
    """
    import tempfile  # local import: only this test needs a scratch directory

    model = torch.nn.Sequential(
        Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.Linear(64, 10),
    )
    q_model = quantize_model(model)

    # Write the checkpoints into a temporary directory so they are removed
    # even when saving or the assertion fails, and so the test never touches
    # files of the same name in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.pth")
        q_model_path = os.path.join(tmp_dir, "q_model.pth")
        torch.save(model.state_dict(), model_path)
        torch.save(q_model.state_dict(), q_model_path)
        model_size = os.path.getsize(model_path)
        q_model_size = os.path.getsize(q_model_path)

    assert q_model_size * 3.8 < model_size
Exemple #4
0
def main(args):
    """Train an MNIST MLP and prune it, either after or during training.

    With ``args.vanilla_pruning`` set, the network is trained with a
    ``dl.SupervisedRunner`` and then handed to ``validate_model`` together
    with a ``utils.pruning.prune_model`` partial; the two returned values are
    saved to ``accuracy.pth`` and ``amount.pth``. Otherwise the network is
    trained with ``PruneRunner`` and a ``dl.PruningCallback``, optionally
    adding the lottery-ticket callback (``args.lottery_ticket``) and the
    knowledge-distillation callbacks (``args.kd``).

    Args:
        args: options object (presumably an argparse namespace). Attributes
            read here: ``seed``, ``device``, ``vanilla_pruning``,
            ``num_epochs``, ``pruning_method``, ``amount``, ``dim``, ``n``,
            ``num_sessions``, ``lottery_ticket``, ``kd``, ``state_dict``,
            ``probability_shift`` and ``logdir``.
    """
    import copy  # local import: only needed for the weight snapshot below

    train_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=True, transform=ToTensor())
    )
    val_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=False, transform=ToTensor())
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64)
    loaders = {"train": train_dataloader, "valid": val_dataloader}
    utils.set_global_seed(args.seed)
    net = nn.Sequential(
        Flatten(),
        nn.Linear(28 * 28, 300),
        nn.ReLU(),
        nn.Linear(300, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    # ``state_dict()`` returns tensors that share storage with the live
    # parameters, so without a deep copy the "initial" weights would be
    # overwritten by training or by ``load_state_dict`` below.
    initial_state_dict = copy.deepcopy(net.state_dict())

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    if args.device is not None:
        engine = dl.DeviceEngine(args.device)
    else:
        engine = None
    if args.vanilla_pruning:
        runner = dl.SupervisedRunner(engine=engine)

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            ],
            logdir="./logdir",
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
        pruning_fn = partial(
            utils.pruning.prune_model,
            pruning_fn=args.pruning_method,
            amount=args.amount,
            # ``nn.Linear`` names its parameter ``weight`` (not ``weights``);
            # this matches the key used by the PruningCallback branch below.
            keys_to_prune=["weight"],
            dim=args.dim,
            l_norm=args.n,
        )
        acc, amount = validate_model(
            runner, pruning_fn=pruning_fn, loader=loaders["valid"], num_sessions=args.num_sessions
        )
        torch.save(acc, "accuracy.pth")
        torch.save(amount, "amount.pth")

    else:
        runner = PruneRunner(num_sessions=args.num_sessions, engine=engine)
        callbacks = [
            dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            dl.PruningCallback(
                args.pruning_method,
                keys_to_prune=["weight"],
                amount=args.amount,
                remove_reparametrization_on_stage_end=False,
            ),
            dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
        ]
        if args.lottery_ticket:
            callbacks.append(LotteryTicketCallback(initial_state_dict=initial_state_dict))
        if args.kd:
            # Start from the weights stored at ``args.state_dict``.
            net.load_state_dict(torch.load(args.state_dict))
            callbacks.append(
                PrepareForFinePruningCallback(probability_shift=args.probability_shift)
            )
            callbacks.append(KLDivCallback(temperature=4, student_logits_key="logits"))
            # Final loss = 0.1 * loss + 0.9 * kl_div_loss.
            callbacks.append(
                MetricAggregationCallback(
                    prefix="loss", metrics={"loss": 0.1, "kl_div_loss": 0.9}, mode="weighted_sum"
                )
            )

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=callbacks,
            logdir=args.logdir,
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )