def test_accuracy():
    """Check that accuracy does not degrade after model quantization."""
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.Linear(64, 10),
    )
    # The small test split is used for both loaders to keep the test fast.
    datasets = {
        "train": MNIST(DATA_ROOT, train=False),
        "valid": MNIST(DATA_ROOT, train=False),
    }
    dataloaders = {
        k: torch.utils.data.DataLoader(d, batch_size=32) for k, d in datasets.items()
    }
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    runner = SupervisedRunner()
    runner.train(
        model=model,
        optimizer=optimizer,
        loaders=dataloaders,
        callbacks=[AccuracyCallback(target_key="targets", input_key="logits")],
        num_epochs=1,
        criterion=torch.nn.CrossEntropyLoss(),
        valid_loader="valid",
        valid_metric="accuracy01",
        minimize_valid_metric=False,
    )
    # Accuracy may not change by more than 0.01 after quantization.
    accuracy_before = _evaluate_loader_accuracy(runner, dataloaders["valid"])
    q_model = quantize_model(model)
    runner.model = q_model
    accuracy_after = _evaluate_loader_accuracy(runner, dataloaders["valid"])
    assert abs(accuracy_before - accuracy_after) < 0.01
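# The test above calls a helper `_evaluate_loader_accuracy` that is not shown
# in this excerpt. A minimal sketch of such a helper, assuming it only runs the
# runner's current model over a loader and returns top-1 accuracy (hypothetical,
# not the project's actual implementation):
import torch


@torch.no_grad()
def _evaluate_loader_accuracy(runner, loader):
    """Hypothetical helper: top-1 accuracy of runner.model over a loader."""
    model = runner.model
    model.eval()
    correct, total = 0, 0
    for features, targets in loader:
        preds = model(features).argmax(dim=-1)
        correct += (preds == targets).sum().item()
        total += targets.numel()
    return correct / total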
def get_callbacks(self, stage: str):
    return {
        "criterion": CriterionCallback(
            metric_key="loss", input_key="logits", target_key="targets"
        ),
        "accuracy": AccuracyCallback(
            input_key="logits", target_key="targets", topk_args=(1,)
        ),
        "auc": AUCCallback(input_key="scores", target_key="targets_onehot"),
        "classification": PrecisionRecallF1SupportCallback(
            input_key="logits",
            target_key="targets",
            num_classes=4,
        ),
        # "optimizer": OptimizerCallback(metric_key="loss"),
        # "checkpoint": CheckpointCallback(
        #     self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
        # ),
        # "verbose": TqdmCallback(),
        "counter": CounterCallback(),
    }
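# `CounterCallback` is project-specific rather than a Catalyst built-in. A
# minimal sketch of such a callback, assuming it merely counts processed
# batches per loader (illustrative, not the project's actual implementation):
from collections import defaultdict

from catalyst.core import Callback, CallbackOrder


class CounterCallback(Callback):
    """Hypothetical sketch: count batches seen per loader during a run."""

    def __init__(self):
        super().__init__(order=CallbackOrder.External)
        self.counters = defaultdict(int)

    def on_batch_end(self, runner) -> None:
        self.counters[runner.loader_key] += 1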
def test_pruning():
    from catalyst.callbacks import (
        AccuracyCallback,
        ControlFlowCallback,
        CriterionCallback,
        OptimizerCallback,
        PruningCallback,
    )
    from catalyst.contrib.datasets import MNIST
    import torch
    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor

    from compressors.distillation.callbacks import KLDivCallback, MetricAggregationCallback
    from compressors.models import MLP
    from compressors.pruning.runners import PruneRunner
    from compressors.utils.data import TorchvisionDatasetWrapper as Wrp

    model = MLP(num_layers=3)

    datasets = {
        "train": Wrp(MNIST("./data", train=True, download=True, transform=ToTensor())),
        "valid": Wrp(MNIST("./data", train=False, transform=ToTensor())),
    }
    loaders = {
        dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
        for dl_key, dataset in datasets.items()
    }

    optimizer = torch.optim.Adam(model.parameters())

    runner = PruneRunner(num_sessions=10)
    runner.train(
        model=model,
        loaders=loaders,
        optimizer=optimizer,
        criterion=torch.nn.CrossEntropyLoss(),
        callbacks=[
            PruningCallback(
                pruning_fn="l1_unstructured",
                amount=0.2,
                remove_reparametrization_on_stage_end=False,
            ),
            OptimizerCallback(metric_key="loss"),
            CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            AccuracyCallback(input_key="logits", target_key="targets"),
        ],
        logdir="./pruned_model",
        valid_loader="valid",
        valid_metric="accuracy",
        minimize_valid_metric=False,
        check=True,
    )
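# With `remove_reparametrization_on_stage_end=False`, the pruned layers keep
# their `weight_orig`/`weight_mask` reparametrization after training. If the
# masks should be folded back into plain weights afterwards, the standard
# torch.nn.utils.prune API does that; a small sketch, assuming the MLP's
# prunable layers are nn.Linear modules:
import torch.nn as nn
from torch.nn.utils import prune


def bake_in_pruning_masks(model: nn.Module) -> nn.Module:
    """Replace weight_orig * weight_mask with a plain .weight tensor."""
    for module in model.modules():
        if isinstance(module, nn.Linear) and prune.is_pruned(module):
            prune.remove(module, "weight")
    return model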
def test_distil():
    from itertools import chain

    from catalyst.callbacks import AccuracyCallback, OptimizerCallback
    from catalyst.contrib.datasets import MNIST
    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms as T

    from compressors.distillation.runners import EndToEndDistilRunner
    from compressors.models import MLP
    from compressors.utils.data import TorchvisionDatasetWrapper as Wrp

    teacher = MLP(num_layers=4)
    student = MLP(num_layers=3)

    datasets = {
        "train": Wrp(MNIST("./data", train=True, download=True, transform=T.ToTensor())),
        "valid": Wrp(MNIST("./data", train=False, transform=T.ToTensor())),
    }
    loaders = {
        dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
        for dl_key, dataset in datasets.items()
    }

    optimizer = torch.optim.Adam(chain(teacher.parameters(), student.parameters()))

    runner = EndToEndDistilRunner(hidden_state_loss="mse", num_train_teacher_epochs=5)
    runner.train(
        model=torch.nn.ModuleDict({"teacher": teacher, "student": student}),
        loaders=loaders,
        optimizer=optimizer,
        num_epochs=4,
        callbacks=[
            OptimizerCallback(metric_key="loss"),
            AccuracyCallback(input_key="logits", target_key="targets"),
        ],
        valid_metric="accuracy01",
        minimize_valid_metric=False,
        logdir="./logs",
        valid_loader="valid",
        criterion=torch.nn.CrossEntropyLoss(),
        check=True,
    )
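# For reference, the soft-target part of a distillation objective (what a
# callback such as KLDivCallback(temperature=T) typically contributes) is the
# temperature-scaled KL divergence between softened teacher and student
# distributions. A standalone sketch of that loss, not the library's exact code:
import torch
import torch.nn.functional as F


def kd_kl_div(
    s_logits: torch.Tensor, t_logits: torch.Tensor, temperature: float = 4.0
) -> torch.Tensor:
    """KL(teacher || student) on temperature-softened logits, scaled by T^2."""
    t = temperature
    return F.kl_div(
        F.log_softmax(s_logits / t, dim=-1),
        F.softmax(t_logits / t, dim=-1),
        reduction="batchmean",
    ) * (t * t)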
def main(args):
    set_global_seed(42)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    datasets = {
        "train": Wrp(CIFAR100(root=".", train=True, download=True, transform=transform_train)),
        "valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
    }
    loaders = {
        k: DataLoader(v, batch_size=args.batch_size, shuffle=k == "train", num_workers=2)
        for k, v in datasets.items()
    }

    teacher_model = NAME2MODEL[args.teacher](num_classes=100)
    if args.teacher_path is None:
        teacher_sd = load_state_dict_from_url(NAME2URL[args.teacher])
        teacher_model.load_state_dict(teacher_sd)
    else:
        unpack_checkpoint(torch.load(args.teacher_path), model=teacher_model)
    student_model = NAME2MODEL[args.student](num_classes=100)

    optimizer = torch.optim.SGD(
        student_model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [150, 180, 210], gamma=0.1)

    runner = DistilRunner(apply_probability_shift=args.probability_shift)
    runner.train(
        model={"teacher": teacher_model, "student": student_model},
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        valid_metric="accuracy",
        minimize_valid_metric=False,
        logdir=args.logdir,
        callbacks=[
            ControlFlowCallback(AttentionHiddenStatesCallback(), loaders="train"),
            ControlFlowCallback(KLDivCallback(temperature=4), loaders="train"),
            CriterionCallback(input_key="s_logits", target_key="targets", metric_key="cls_loss"),
            ControlFlowCallback(
                MetricAggregationCallback(
                    prefix="loss",
                    metrics={
                        "attention_loss": args.beta,
                        "kl_div_loss": args.alpha,
                        "cls_loss": 1 - args.alpha,
                    },
                    mode="weighted_sum",
                ),
                loaders="train",
            ),
            AccuracyCallback(input_key="s_logits", target_key="targets"),
            OptimizerCallback(metric_key="loss", model_key="student"),
            SchedulerCallback(),
        ],
        valid_loader="valid",
        num_epochs=args.num_epochs,
        criterion=torch.nn.CrossEntropyLoss(),
        seed=args.seed,
    )
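# The MetricAggregationCallback above combines the per-batch losses as a
# weighted sum: loss = beta * attention_loss + alpha * kl_div_loss
# + (1 - alpha) * cls_loss. A standalone sketch of that aggregation (the math
# behind the config, not the callback's internal code):
from typing import Dict

import torch


def aggregate_distil_loss(
    metrics: Dict[str, torch.Tensor], alpha: float, beta: float
) -> torch.Tensor:
    """Weighted sum matching the MetricAggregationCallback config above."""
    return (
        beta * metrics["attention_loss"]
        + alpha * metrics["kl_div_loss"]
        + (1.0 - alpha) * metrics["cls_loss"]
    )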
datasets = {
    "train": Wrp(MNIST("./data", train=True, download=True, transform=T.ToTensor())),
    "valid": Wrp(MNIST("./data", train=False, transform=T.ToTensor())),
}
loaders = {
    dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
    for dl_key, dataset in datasets.items()
}

optimizer = torch.optim.Adam(chain(teacher.parameters(), student.parameters()))

runner = EndToEndDistilRunner(hidden_state_loss="mse", num_train_teacher_epochs=5)
runner.train(
    model=torch.nn.ModuleDict({"teacher": teacher, "student": student}),
    loaders=loaders,
    optimizer=optimizer,
    num_epochs=4,
    callbacks=[
        OptimizerCallback(metric_key="loss"),
        AccuracyCallback(input_key="logits", target_key="targets"),
    ],
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    logdir="./logs",
    valid_loader="valid",
    criterion=torch.nn.CrossEntropyLoss(),
    check=True,
)
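# After a run like the one above, the trained student can be used on its own:
# nn.ModuleDict only holds references, so `student` already contains the
# trained weights once runner.train(...) returns (the path below is illustrative).
import torch

torch.save(student.state_dict(), "./logs/student.pth")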
loaders = {
    dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
    for dl_key, dataset in datasets.items()
}

optimizer = torch.optim.Adam(teacher.parameters(), lr=1e-2)

runner = SupervisedRunner()
runner.train(
    model=teacher,
    loaders=loaders,
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),
    callbacks=[AccuracyCallback(input_key="logits", target_key="targets")],
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    num_epochs=5,
)

mse_callback = MSEHiddenStatesCallback()
select_last_hidden_states = HiddenStatesSelectCallback(layers=[2, 3])
kl_callback = KLDivCallback()

runner = DistilRunner(output_hidden_states=True)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-2)
runner.train(
    model={
        "teacher": teacher,
def main(args):
    set_global_seed(args.seed)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    datasets = {
        "train": Wrp(CIFAR100(root=".", train=True, download=True, transform=transform_train)),
        "valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
    }
    loaders = {
        k: DataLoader(v, batch_size=args.batch_size, shuffle=k == "train", num_workers=2)
        for k, v in datasets.items()
    }

    model = NAME2MODEL[args.model](num_classes=100)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [150, 180, 210], gamma=0.1)

    runner = SupervisedRunner()
    runner.train(
        model=model,
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        valid_metric="accuracy",
        minimize_valid_metric=False,
        logdir=args.logdir,
        callbacks=[
            CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            AccuracyCallback(input_key="logits", target_key="targets"),
            OptimizerCallback(metric_key="loss"),
            SchedulerCallback(),
        ],
        valid_loader="valid",
        num_epochs=args.num_epochs,
        criterion=torch.nn.CrossEntropyLoss(),
        seed=args.seed,
    )
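# `main` expects an argparse-style namespace. A minimal parser covering the
# attributes the function actually reads (argument names inferred from the
# code above; defaults are illustrative, not the project's real ones):
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="CIFAR100 baseline training")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--num-epochs", type=int, default=240)
    parser.add_argument("--logdir", type=str, default="./logs")
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args())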