# Imports assumed for this standalone snippet; DATA_ROOT, quantize_model and
# _evaluate_loader_accuracy come from the surrounding project and are not shown here.
import torch
from catalyst.callbacks import AccuracyCallback
from catalyst.contrib.datasets import MNIST
from catalyst.runners import SupervisedRunner


def test_accuracy():
    """Check that quantization changes validation accuracy by less than 0.01."""
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 64),
        torch.nn.Linear(64, 10),
    )
    datasets = {
        "train": MNIST(DATA_ROOT, train=False),
        "valid": MNIST(DATA_ROOT, train=False),
    }
    dataloaders = {
        k: torch.utils.data.DataLoader(d, batch_size=32) for k, d in datasets.items()
    }
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    runner = SupervisedRunner()
    runner.train(
        model=model,
        optimizer=optimizer,
        loaders=dataloaders,
        callbacks=[AccuracyCallback(target_key="targets", input_key="logits")],
        num_epochs=1,
        criterion=torch.nn.CrossEntropyLoss(),
        valid_loader="valid",
        valid_metric="accuracy01",
        minimize_valid_metric=False,
    )
    accuracy_before = _evaluate_loader_accuracy(runner, dataloaders["valid"])
    q_model = quantize_model(model)
    runner.model = q_model
    accuracy_after = _evaluate_loader_accuracy(runner, dataloaders["valid"])
    assert abs(accuracy_before - accuracy_after) < 0.01
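
The helper _evaluate_loader_accuracy used above is defined outside this snippet. A minimal sketch of what it might look like, assuming the loader yields (features, targets) batches and the model under test is available as runner.model:

import torch


def _evaluate_loader_accuracy(runner, loader):
    """Hypothetical helper: plain top-1 accuracy of runner.model over a loader."""
    model = runner.model
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for features, targets in loader:
            logits = model(features.float())
            correct += (logits.argmax(dim=-1) == targets).sum().item()
            total += targets.numel()
    return correct / total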
Example 2
def get_callbacks(self, stage: str):
    return {
        "criterion": CriterionCallback(
            metric_key="loss", input_key="logits", target_key="targets"
        ),
        "accuracy": AccuracyCallback(
            input_key="logits", target_key="targets", topk_args=(1,)
        ),
        "auc": AUCCallback(input_key="scores", target_key="targets_onehot"),
        "classification": PrecisionRecallF1SupportCallback(
            input_key="logits",
            target_key="targets",
            num_classes=4,
        ),
        # "optimizer": OptimizerCallback(metric_key="loss"),
        # "checkpoint": CheckpointCallback(
        #     self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
        # ),
        # "verbose": TqdmCallback(),
        "counter": CounterCallback(),
    }
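
CounterCallback in the dictionary above is a project-local callback, not a built-in Catalyst one. A minimal sketch of what it might look like, assuming Catalyst's Callback / CallbackOrder API and that it only needs to count processed batches:

from catalyst.core import Callback, CallbackOrder


class CounterCallback(Callback):
    """Hypothetical callback that simply counts how many batches were processed."""

    def __init__(self):
        super().__init__(order=CallbackOrder.External)
        self.counter = 0

    def on_batch_end(self, runner) -> None:
        self.counter += 1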
Example 3
def test_pruning():
    from catalyst.callbacks import (
        AccuracyCallback,
        ControlFlowCallback,
        CriterionCallback,
        OptimizerCallback,
        PruningCallback,
    )
    from catalyst.contrib.datasets import MNIST
    import torch
    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor

    from compressors.distillation.callbacks import KLDivCallback, MetricAggregationCallback
    from compressors.models import MLP
    from compressors.pruning.runners import PruneRunner
    from compressors.utils.data import TorchvisionDatasetWrapper as Wrp

    model = MLP(num_layers=3)

    datasets = {
        "train": Wrp(MNIST("./data", train=True, download=True, transform=ToTensor())),
        "valid": Wrp(MNIST("./data", train=False, transform=ToTensor())),
    }

    loaders = {
        dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
        for dl_key, dataset in datasets.items()
    }

    optimizer = torch.optim.Adam(model.parameters())

    runner = PruneRunner(num_sessions=10)

    runner.train(
        model=model,
        loaders=loaders,
        optimizer=optimizer,
        criterion=torch.nn.CrossEntropyLoss(),
        callbacks=[
            PruningCallback(
                pruning_fn="l1_unstructured",
                amount=0.2,
                remove_reparametrization_on_stage_end=False,
            ),
            OptimizerCallback(metric_key="loss"),
            CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            AccuracyCallback(input_key="logits", target_key="targets"),
        ],
        logdir="./pruned_model",
        valid_loader="valid",
        valid_metric="accuracy",
        minimize_valid_metric=False,
        check=True,
    )
Example 4
def test_distil():
    from itertools import chain

    from catalyst.callbacks import AccuracyCallback, OptimizerCallback
    from catalyst.contrib.datasets import MNIST
    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms as T

    from compressors.distillation.runners import EndToEndDistilRunner
    from compressors.models import MLP
    from compressors.utils.data import TorchvisionDatasetWrapper as Wrp

    teacher = MLP(num_layers=4)
    student = MLP(num_layers=3)

    datasets = {
        "train": Wrp(MNIST("./data", train=True, download=True, transform=T.ToTensor())),
        "valid": Wrp(MNIST("./data", train=False, transform=T.ToTensor())),
    }

    loaders = {
        dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
        for dl_key, dataset in datasets.items()
    }

    optimizer = torch.optim.Adam(
        chain(teacher.parameters(), student.parameters()))

    runner = EndToEndDistilRunner(hidden_state_loss="mse",
                                  num_train_teacher_epochs=5)

    runner.train(
        model=torch.nn.ModuleDict({
            "teacher": teacher,
            "student": student
        }),
        loaders=loaders,
        optimizer=optimizer,
        num_epochs=4,
        callbacks=[
            OptimizerCallback(metric_key="loss"),
            AccuracyCallback(input_key="logits", target_key="targets"),
        ],
        valid_metric="accuracy01",
        minimize_valid_metric=False,
        logdir="./logs",
        valid_loader="valid",
        criterion=torch.nn.CrossEntropyLoss(),
        check=True,
    )
Example 5
def main(args):

    set_global_seed(42)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    datasets = {
        "train": Wrp(
            CIFAR100(root=".", train=True, download=True, transform=transform_train)
        ),
        "valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
    }

    loaders = {
        k: DataLoader(v,
                      batch_size=args.batch_size,
                      shuffle=k == "train",
                      num_workers=2)
        for k, v in datasets.items()
    }
    teacher_model = NAME2MODEL[args.teacher](num_classes=100)
    if args.teacher_path is None:
        teacher_sd = load_state_dict_from_url(NAME2URL[args.teacher])
        teacher_model.load_state_dict(teacher_sd)
    else:
        unpack_checkpoint(torch.load(args.teacher_path), model=teacher_model)
    student_model = NAME2MODEL[args.student](num_classes=100)

    optimizer = torch.optim.SGD(student_model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     [150, 180, 210],
                                                     gamma=0.1)

    runner = DistilRunner(apply_probability_shift=args.probability_shift)
    runner.train(
        model={"teacher": teacher_model, "student": student_model},
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        valid_metric="accuracy",
        minimize_valid_metric=False,
        logdir=args.logdir,
        callbacks=[
            ControlFlowCallback(AttentionHiddenStatesCallback(), loaders="train"),
            ControlFlowCallback(KLDivCallback(temperature=4), loaders="train"),
            CriterionCallback(
                input_key="s_logits", target_key="targets", metric_key="cls_loss"
            ),
            ControlFlowCallback(
                MetricAggregationCallback(
                    prefix="loss",
                    metrics={
                        "attention_loss": args.beta,
                        "kl_div_loss": args.alpha,
                        "cls_loss": 1 - args.alpha,
                    },
                    mode="weighted_sum",
                ),
                loaders="train",
            ),
            AccuracyCallback(input_key="s_logits", target_key="targets"),
            OptimizerCallback(metric_key="loss", model_key="student"),
            SchedulerCallback(),
        ],
        valid_loader="valid",
        num_epochs=args.num_epochs,
        criterion=torch.nn.CrossEntropyLoss(),
        seed=args.seed,
    )
datasets = {
    "train": Wrp(MNIST("./data", train=True, download=True, transform=T.ToTensor())),
    "valid": Wrp(MNIST("./data", train=False, transform=T.ToTensor())),
}

loaders = {
    dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
    for dl_key, dataset in datasets.items()
}

optimizer = torch.optim.Adam(chain(teacher.parameters(), student.parameters()))

runner = EndToEndDistilRunner(hidden_state_loss="mse", num_train_teacher_epochs=5)

runner.train(
    model=torch.nn.ModuleDict({"teacher": teacher, "student": student}),
    loaders=loaders,
    optimizer=optimizer,
    num_epochs=4,
    callbacks=[
        OptimizerCallback(metric_key="loss"),
        AccuracyCallback(input_key="logits", target_key="targets"),
    ],
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    logdir="./logs",
    valid_loader="valid",
    criterion=torch.nn.CrossEntropyLoss(),
    check=True,
)
loaders = {
    dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
    for dl_key, dataset in datasets.items()
}

optimizer = torch.optim.Adam(teacher.parameters(), lr=1e-2)

runner = SupervisedRunner()

runner.train(
    model=teacher,
    loaders=loaders,
    optimizer=optimizer,
    criterion=torch.nn.CrossEntropyLoss(),
    callbacks=[AccuracyCallback(input_key="logits", target_key="targets")],
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    num_epochs=5,
)

mse_callback = MSEHiddenStatesCallback()
select_last_hidden_states = HiddenStatesSelectCallback(layers=[2, 3])
kl_callback = KLDivCallback()

runner = DistilRunner(output_hidden_states=True)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-2)

runner.train(
    model={
        "teacher": teacher,
Example 8
def main(args):

    set_global_seed(args.seed)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    datasets = {
        "train": Wrp(
            CIFAR100(root=".", train=True, download=True, transform=transform_train)
        ),
        "valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
    }

    loaders = {
        k: DataLoader(v,
                      batch_size=args.batch_size,
                      shuffle=k == "train",
                      num_workers=2)
        for k, v in datasets.items()
    }
    model = NAME2MODEL[args.model](num_classes=100)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     [150, 180, 210],
                                                     gamma=0.1)

    runner = SupervisedRunner()
    runner.train(
        model=model,
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        valid_metric="accuracy",
        minimize_valid_metric=False,
        logdir=args.logdir,
        callbacks=[
            CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            AccuracyCallback(input_key="logits", target_key="targets"),
            OptimizerCallback(metric_key="loss"),
            SchedulerCallback(),
        ],
        valid_loader="valid",
        num_epochs=args.num_epochs,
        criterion=torch.nn.CrossEntropyLoss(),
        seed=args.seed,
    )
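
The main(args) entry points in Examples 5 and 8 expect an args namespace built elsewhere in the script. A hypothetical argparse setup covering the fields Example 8 reads (flag names and defaults are assumptions, not the original script's values):

import argparse


def parse_args():
    # Hypothetical CLI; only the attributes accessed by main() above are defined.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)     # key into NAME2MODEL
    parser.add_argument("--batch-size", type=int, default=128)  # read as args.batch_size
    parser.add_argument("--lr", type=float, default=0.05)
    parser.add_argument("--num-epochs", type=int, default=240)  # read as args.num_epochs
    parser.add_argument("--logdir", type=str, default="./logs")
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args())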