# NOTE: these snippets come from different parts of the repo and target
# slightly different Catalyst versions. Project-local helpers that are not
# shown here (_evaluate_loader_accuracy, validate_model, PruneRunner,
# TorchvisionDatasetWrapper, OptunaCallback, LotteryTicketCallback,
# PrepareForFinePruningCallback, KLDivCallback, MetricAggregationCallback)
# are assumed to be importable from the surrounding package.
import os
from copy import deepcopy
from functools import partial

import optuna
import torch
from torch import nn
from torch.nn import Flatten
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from catalyst import dl, utils
from catalyst.dl import AccuracyCallback, SupervisedRunner
from catalyst.utils import quantize_model


def test_accuracy():
    """Test that accuracy does not drop by more than 1% after quantization."""
    model = torch.nn.Sequential(
        Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.Linear(64, 10),
    )
    datasets = {
        "train": MNIST("./data", transform=ToTensor(), download=True),
        "valid": MNIST("./data", transform=ToTensor(), train=False),
    }
    dataloaders = {
        k: torch.utils.data.DataLoader(d, batch_size=32) for k, d in datasets.items()
    }
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    runner = SupervisedRunner()
    runner.train(
        model=model,
        optimizer=optimizer,
        loaders=dataloaders,
        callbacks=[AccuracyCallback(target_key="targets", input_key="logits")],
        num_epochs=1,
        criterion=torch.nn.CrossEntropyLoss(),
        valid_loader="valid",
        valid_metric="accuracy01",
        minimize_valid_metric=False,
    )
    # Compare the fp32 model against its quantized counterpart on the same loader.
    accuracy_before = _evaluate_loader_accuracy(runner, dataloaders["valid"])
    q_model = quantize_model(model)
    runner.model = q_model
    accuracy_after = _evaluate_loader_accuracy(runner, dataloaders["valid"])
    assert abs(accuracy_before - accuracy_after) < 0.01

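# ``_evaluate_loader_accuracy`` is a project-local helper that is not shown
# in this section. A minimal sketch of what it might look like, assuming the
# runner's model maps a batch of MNIST images to class logits (hypothetical,
# for illustration only):
def _evaluate_loader_accuracy_sketch(runner, loader):
    """Plain top-1 accuracy of ``runner.model`` over a dataloader."""
    model = runner.model
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for features, targets in loader:
            logits = model(features)
            correct += (logits.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    return correct / total
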
def test_mnist():
    """Tune the learning rate with Optuna on MNIST (smoke test).

    Note: this snippet targets the older Catalyst 20.x API
    (``main_metric``/``minimize_metric``); the other snippets use 21.x names.
    """
    # The test split is reused as "train" to keep this smoke test fast.
    trainset = MNIST("./data", train=False, download=True, transform=ToTensor())
    testset = MNIST("./data", train=False, download=True, transform=ToTensor())
    loaders = {
        "train": DataLoader(trainset, batch_size=32),
        "valid": DataLoader(testset, batch_size=64),
    }

    def objective(trial):
        lr = trial.suggest_float("lr", 1e-3, 1e-1, log=True)
        # Re-create the model inside the objective so every trial starts from
        # fresh weights instead of continuing from the previous trial.
        model = nn.Sequential(
            Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        runner = dl.SupervisedRunner()
        runner.train(
            model=model,
            loaders=loaders,
            criterion=criterion,
            optimizer=optimizer,
            callbacks=[
                OptunaCallback(trial),
                AccuracyCallback(num_classes=10),
            ],
            num_epochs=10,
            main_metric="accuracy01",
            minimize_metric=False,
        )
        return runner.best_valid_metrics[runner.main_metric]

    study = optuna.create_study(
        direction="maximize",
        pruner=optuna.pruners.MedianPruner(
            n_startup_trials=1, n_warmup_steps=0, interval_steps=1
        ),
    )
    study.optimize(objective, n_trials=5, timeout=300)
    # Smoke test: we only check that the study runs end-to-end.
    assert True

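# After ``study.optimize`` returns, the standard Optuna API can summarize the
# search (illustrative helper, not part of the original test):
def _report_best_trial(study):
    best = study.best_trial
    print(f"best accuracy01={best.value:.4f} with lr={best.params['lr']:.2e}")
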
def test_api():
    """Test that a model can be quantized through the API."""
    model = torch.nn.Sequential(
        Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.Linear(64, 10),
    )
    q_model = quantize_model(model)
    torch.save(model.state_dict(), "model.pth")
    torch.save(q_model.state_dict(), "q_model.pth")
    model_size = os.path.getsize("model.pth")
    q_model_size = os.path.getsize("q_model.pth")
    # int8 weights should make the checkpoint roughly 4x smaller than fp32.
    assert q_model_size * 3.8 < model_size
    os.remove("model.pth")
    os.remove("q_model.pth")

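# Why roughly 4x: ``quantize_model`` is assumed here to perform int8 dynamic
# quantization of the Linear layers, storing 1-byte weights instead of 4-byte
# fp32. An equivalent sketch using the stock PyTorch API:
def _quantize_model_sketch(model):
    return torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
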
def main(args):
    """Train an MNIST classifier, then prune it in one shot or iteratively."""
    train_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=True, transform=ToTensor())
    )
    val_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=False, transform=ToTensor())
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=32, shuffle=True
    )
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64)
    loaders = {"train": train_dataloader, "valid": val_dataloader}
    utils.set_global_seed(args.seed)
    net = nn.Sequential(
        Flatten(),
        nn.Linear(28 * 28, 300),
        nn.ReLU(),
        nn.Linear(300, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    # ``state_dict()`` returns references to the live parameters, so keep a
    # deep copy if the initial weights are needed for lottery-ticket resets.
    initial_state_dict = deepcopy(net.state_dict())
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    if args.device is not None:
        engine = dl.DeviceEngine(args.device)
    else:
        engine = None
    if args.vanilla_pruning:
        # One-shot baseline: train once, then prune and validate repeatedly.
        runner = dl.SupervisedRunner(engine=engine)
        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="logits", target_key="targets", num_classes=10
                ),
            ],
            logdir="./logdir",
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
        pruning_fn = partial(
            utils.pruning.prune_model,
            pruning_fn=args.pruning_method,
            amount=args.amount,
            keys_to_prune=["weight"],  # the parameter name is "weight", not "weights"
            dim=args.dim,
            l_norm=args.n,
        )
        acc, amount = validate_model(
            runner,
            pruning_fn=pruning_fn,
            loader=loaders["valid"],
            num_sessions=args.num_sessions,
        )
        torch.save(acc, "accuracy.pth")
        torch.save(amount, "amount.pth")
    else:
        # Iterative pruning: alternate training sessions with PruningCallback.
        runner = PruneRunner(num_sessions=args.num_sessions, engine=engine)
        callbacks = [
            dl.AccuracyCallback(
                input_key="logits", target_key="targets", num_classes=10
            ),
            dl.PruningCallback(
                args.pruning_method,
                keys_to_prune=["weight"],
                amount=args.amount,
                remove_reparametrization_on_stage_end=False,
            ),
            dl.CriterionCallback(
                input_key="logits", target_key="targets", metric_key="loss"
            ),
            dl.OptimizerCallback(metric_key="loss"),
        ]
        if args.lottery_ticket:
            # Rewind weights to their initial values after each pruning session.
            callbacks.append(
                LotteryTicketCallback(initial_state_dict=initial_state_dict)
            )
        if args.kd:
            # Fine pruning with distillation: start from a pretrained checkpoint
            # and mix a KL-divergence term into the training loss.
            net.load_state_dict(torch.load(args.state_dict))
            callbacks.append(
                PrepareForFinePruningCallback(probability_shift=args.probability_shift)
            )
            callbacks.append(KLDivCallback(temperature=4, student_logits_key="logits"))
            callbacks.append(
                MetricAggregationCallback(
                    prefix="loss",
                    metrics={"loss": 0.1, "kl_div_loss": 0.9},
                    mode="weighted_sum",
                )
            )
        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=callbacks,
            logdir=args.logdir,
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )

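# ``main`` reads the attributes below from ``args``. A minimal argparse sketch:
# the attribute names are taken from the code above, but the flag spellings
# and defaults are illustrative assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--vanilla-pruning", action="store_true")
    parser.add_argument("--pruning-method", type=str, default="l1_unstructured")
    parser.add_argument("--amount", type=float, default=0.5)
    parser.add_argument("--dim", type=int, default=None)
    parser.add_argument("--n", type=int, default=None)
    parser.add_argument("--num-sessions", type=int, default=5)
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--lottery-ticket", action="store_true")
    parser.add_argument("--kd", action="store_true")
    parser.add_argument("--state-dict", type=str, default=None)
    parser.add_argument("--probability-shift", action="store_true")
    parser.add_argument("--logdir", type=str, default="./logdir")
    main(parser.parse_args())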