Example #1
 def get_callbacks(self, stage: str):
     return {
         "criterion": CriterionCallback(
             metric_key="loss", input_key="logits", target_key="targets"
         ),
         "optimizer": OptimizerCallback(metric_key="loss"),
         # "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
         "checkpoint": CheckpointCallback(
             self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
         ),
         "test_nn_module": ModuleTypeChecker(),
         "test_device": DeviceCheckCallback(self._device, logger=logger),
         "test_loss_minimization": LossMinimizationCallback("loss", logger=logger),
         "test_logits_type": TensorTypeChecker("logits"),
         # "loss_type_checker": TensorTypeChecker("loss", True),
     }
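
This `get_callbacks` override comes from a custom `IRunner` subclass used in Catalyst's tests; `self._logdir` and `self._device` are attributes of that runner. A minimal sketch of the surrounding class, with an illustrative model, criterion, and optimizer standing in for the originals (not the actual test code):

# Sketch of a runner that would own get_callbacks() above; the model,
# criterion, and optimizer are illustrative stand-ins.
from catalyst import dl
import torch


class CustomRunner(dl.IRunner):
    def __init__(self, logdir: str, device: str):
        super().__init__()
        self._logdir = logdir
        self._device = device

    def get_engine(self):
        return dl.DeviceEngine(self._device)

    def get_model(self, stage: str):
        return torch.nn.Linear(28 * 28, 10)

    def get_criterion(self, stage: str):
        return torch.nn.CrossEntropyLoss()

    def get_optimizer(self, stage: str, model):
        return torch.optim.Adam(model.parameters())

    # get_loaders / handle_batch / other IRunner hooks omitted for brevity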
Example #2
 def get_callbacks(self, stage: str):
     return {
         "criterion": CriterionCallback(
             metric_key="loss", input_key="logits", target_key="targets"
         ),
         "accuracy": AccuracyCallback(
             input_key="logits", target_key="targets", topk_args=(1,)
         ),
         "auc": AUCCallback(input_key="scores", target_key="targets_onehot"),
         "classification": PrecisionRecallF1SupportCallback(
             input_key="logits", target_key="targets", num_classes=4
         ),
         # "optimizer": OptimizerCallback(metric_key="loss"),
         # "checkpoint": CheckpointCallback(
         #     self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
         # ),
         # "verbose": TqdmCallback(),
         "counter": CounterCallback(),
     }
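
`AUCCallback` consumes the `scores` and `targets_onehot` keys, which no stock callback produces, so the runner's `handle_batch` has to populate them. A sketch of such a method, assuming a 4-class classifier (the body below is an assumption; the original runner is not shown):

import torch.nn.functional as F


def handle_batch(self, batch):
    x, targets = batch["features"], batch["targets"]
    logits = self.model(x)
    self.batch = {
        "features": x,
        "targets": targets,
        "logits": logits,                                 # criterion / accuracy / classification
        "scores": logits.softmax(dim=1),                  # input for AUCCallback
        "targets_onehot": F.one_hot(targets, 4).float(),  # target for AUCCallback
    }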
Example #3
def test_pruning():
    from catalyst.callbacks import (
        AccuracyCallback,
        ControlFlowCallback,
        CriterionCallback,
        OptimizerCallback,
        PruningCallback,
    )
    from catalyst.contrib.datasets import MNIST
    import torch
    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor

    from compressors.distillation.callbacks import KLDivCallback, MetricAggregationCallback
    from compressors.models import MLP
    from compressors.pruning.runners import PruneRunner
    from compressors.utils.data import TorchvisionDatasetWrapper as Wrp

    model = MLP(num_layers=3)

    datasets = {
        "train": Wrp(MNIST("./data", train=True, download=True, transform=ToTensor())),
        "valid": Wrp(MNIST("./data", train=False, transform=ToTensor())),
    }

    loaders = {
        dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
        for dl_key, dataset in datasets.items()
    }

    optimizer = torch.optim.Adam(model.parameters())

    runner = PruneRunner(num_sessions=10)

    runner.train(model=model,
                 loaders=loaders,
                 optimizer=optimizer,
                 criterion=torch.nn.CrossEntropyLoss(),
                 callbacks=[
                     PruningCallback(
                         pruning_fn="l1_unstructured",
                         amount=0.2,
                         remove_reparametrization_on_stage_end=False,
                     ),
                     OptimizerCallback(metric_key="loss"),
                     CriterionCallback(input_key="logits",
                                       target_key="targets",
                                       metric_key="loss"),
                     AccuracyCallback(input_key="logits",
                                      target_key="targets"),
                 ],
                 logdir="./pruned_model",
                 valid_loader="valid",
                 valid_metric="accuracy",
                 minimize_valid_metric=False,
                 check=True)
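
For context, `pruning_fn="l1_unstructured"` names a function from PyTorch's `torch.nn.utils.prune` module; per pruning session the callback does roughly the following (a sketch of the mechanism, not the callback's actual code):

import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(128, 10)
# Mask the 20% of weights with the smallest L1 magnitude.
prune.l1_unstructured(layer, name="weight", amount=0.2)
# With remove_reparametrization_on_stage_end=False, the
# (weight_orig, weight_mask) reparametrization is kept between sessions;
# prune.remove(layer, "weight") would bake the mask into the weight.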
Example #4
 def get_callbacks(self, stage: str):
     return {
         "criterion": CriterionCallback(
             metric_key="loss", input_key="logits", target_key="targets"
         ),
         "optimizer": OptimizerCallback(metric_key="loss"),
         # "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
         "checkpoint": CheckpointCallback(
             self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
         ),
         "test_nn_parallel_data_parallel": DataParallelTypeChecker(),
         "test_loss_minimization": LossMinimizationCallback("loss", logger=logger),
         "test_logits_type": OPTTensorTypeChecker("logits", self._opt_level),
     }

 def get_callbacks(self, stage: str):
     return {
         "criterion": CriterionCallback(
             metric_key="loss", input_key="logits", target_key="targets"
         ),
         "optimizer": OptimizerCallback(metric_key="loss"),
         # "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
         "checkpoint": CheckpointCallback(
             self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
         ),
         "test_nn_parallel_distributed_data_parallel": DistributedDataParallelTypeChecker(),
         "test_loss_minimization": LossMinimizationCallback("loss", logger=logger),
         "test_world_size": WorldSizeCheckCallback(NUM_CUDA_DEVICES, logger=logger),
     }
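
Callbacks like `ModuleTypeChecker`, `DeviceCheckCallback`, `WorldSizeCheckCallback`, and `LossMinimizationCallback` are assertion helpers from Catalyst's test suite rather than public API. A plausible sketch of `LossMinimizationCallback`, assuming it records the train loss each epoch and asserts it decreases (the actual implementation may differ):

from catalyst.core import Callback, CallbackOrder


class LossMinimizationCallback(Callback):
    def __init__(self, key: str = "loss", logger=None):
        super().__init__(order=CallbackOrder.External)
        self.key = key
        self.logger = logger
        self.losses = []

    def on_epoch_end(self, runner):
        self.losses.append(runner.epoch_metrics["train"][self.key])

    def on_stage_end(self, runner):
        # The training loss is expected to go down over the stage.
        assert self.losses[-1] < self.losses[0]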
def test_tracer_callback():
    """
    Tests a feature of `TracingCallback` for model tracing during training
    """
    logdir = "./logs"
    dataset_root = "./data"
    loaders = _get_loaders(root=dataset_root, batch_size=4, num_workers=1)
    images, targets = next(iter(loaders["train"]))
    _, c, h, w = images.shape
    input_shape = (c, h, w)

    model = _TracedNet(input_shape)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters())

    method_name = "forward"
    mode = "eval"
    requires_grad = False
    checkpoint_name = "best"
    opt_level = None

    trace_name = get_trace_name(
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        additional_string=checkpoint_name,
    )
    tracing_path = Path(logdir) / "trace" / trace_name
    criterion_callback = CriterionCallback()
    optimizer_callback = OptimizerCallback()
    tracer_callback = TracingCallback(
        metric="loss",
        minimize=False,
        trace_mode=mode,
        mode=checkpoint_name,
        do_once=True,
        method_name=method_name,
        requires_grad=requires_grad,
        opt_level=opt_level,
    )
    test_callback = _OnStageEndCheckModelTracedCallback(
        path=tracing_path,
        inputs=images,
    )

    callbacks = collections.OrderedDict(
        loss=criterion_callback,
        optimizer=optimizer_callback,
        tracer_callback=tracer_callback,
        test_callback=test_callback,
    )

    runner = SupervisedRunner(input_key="x")
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir=logdir,
        callbacks=callbacks,
        check=True,
        verbose=True,
    )

    shutil.rmtree(logdir)
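
Under the hood, `TracingCallback` relies on `torch.jit.trace`; the artifact the test expects at `tracing_path` is roughly what this manual equivalent would write:

import torch

model.eval()                               # mode="eval", requires_grad=False
traced = torch.jit.trace(model, images)    # trace forward() with a sample batch
torch.jit.save(traced, str(tracing_path))  # later loadable via torch.jit.load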
Example #7
def main(args):

    set_global_seed(42)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    datasets = {
        "train": Wrp(
            CIFAR100(root=".", train=True, download=True, transform=transform_train)
        ),
        "valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
    }

    loaders = {
        k: DataLoader(v,
                      batch_size=args.batch_size,
                      shuffle=k == "train",
                      num_workers=2)
        for k, v in datasets.items()
    }
    teacher_model = NAME2MODEL[args.teacher](num_classes=100)
    if args.teacher_path is None:
        teacher_sd = load_state_dict_from_url(NAME2URL[args.teacher])
        teacher_model.load_state_dict(teacher_sd)
    else:
        unpack_checkpoint(torch.load(args.teacher_path), model=teacher_model)
    student_model = NAME2MODEL[args.student](num_classes=100)

    optimizer = torch.optim.SGD(student_model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     [150, 180, 210],
                                                     gamma=0.1)

    runner = DistilRunner(apply_probability_shift=args.probability_shift)
    runner.train(model={"teacher": teacher_model, "student": student_model},
                 loaders=loaders,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 valid_metric="accuracy",
                 minimize_valid_metric=False,
                 logdir=args.logdir,
                 callbacks=[
                     ControlFlowCallback(AttentionHiddenStatesCallback(),
                                         loaders="train"),
                     ControlFlowCallback(KLDivCallback(temperature=4),
                                         loaders="train"),
                     CriterionCallback(input_key="s_logits",
                                       target_key="targets",
                                       metric_key="cls_loss"),
                     ControlFlowCallback(
                         MetricAggregationCallback(
                             prefix="loss",
                             metrics={
                                 "attention_loss": args.beta,
                                 "kl_div_loss": args.alpha,
                                 "cls_loss": 1 - args.alpha,
                             },
                             mode="weighted_sum",
                         ),
                         loaders="train",
                     ),
                     AccuracyCallback(input_key="s_logits",
                                      target_key="targets"),
                     OptimizerCallback(metric_key="loss", model_key="student"),
                     SchedulerCallback(),
                 ],
                 valid_loader="valid",
                 num_epochs=args.num_epochs,
                 criterion=torch.nn.CrossEntropyLoss(),
                 seed=args.seed)
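
`NAME2MODEL` and `NAME2URL` are lookup tables defined elsewhere in the script. They presumably map model names to factory callables and pretrained-weight URLs, along these lines (entries below are illustrative placeholders, not the real tables):

from torchvision.models import resnet18, resnet34

# Illustrative placeholders; called as NAME2MODEL[name](num_classes=100).
NAME2MODEL = {
    "resnet18": resnet18,
    "resnet34": resnet34,
}
NAME2URL = {
    "resnet18": "...",  # URL of pretrained CIFAR-100 weights (left elided)
}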
select_last_hidden_states = HiddenStatesSelectCallback(layers=[2, 3])
mse_callback = MSEHiddenStatesCallback()  # referenced below; produces the "mse_loss" key
kl_callback = KLDivCallback()

runner = DistilRunner(output_hidden_states=True)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-2)

runner.train(
    model={"teacher": teacher, "student": student},
    loaders=loaders,
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),
    callbacks=[
        AccuracyCallback(input_key="s_logits", target_key="targets"),
        CriterionCallback(input_key="s_logits", target_key="targets", metric_key="loss"),
        mse_callback,
        select_last_hidden_states,
        kl_callback,
        MetricAggregationCallback(
            prefix="loss",
            metrics={"kl_loss": 0.2, "mse_loss": 0.2, "loss": 0.6},
            mode="weighted_sum",
        ),
    ],
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    num_epochs=5,
)
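
The snippet above is a fragment: `teacher`, `student`, `loaders`, and the distillation callbacks are assumed to be defined earlier. A minimal setup it could run against, with import paths borrowed from Example #3 and an MNIST loader pair as a stand-in for the real data (shapes and paths are assumptions):

# Stand-in setup only; the real teacher/student models and data are not shown.
from catalyst.contrib.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

from compressors.distillation.callbacks import (
    HiddenStatesSelectCallback,
    KLDivCallback,
    MSEHiddenStatesCallback,
)
from compressors.models import MLP
from compressors.utils.data import TorchvisionDatasetWrapper as Wrp

teacher = MLP(num_layers=4)
student = MLP(num_layers=3)
loaders = {
    "train": DataLoader(
        Wrp(MNIST("./data", train=True, download=True, transform=ToTensor())),
        batch_size=32,
        shuffle=True,
    ),
    "valid": DataLoader(
        Wrp(MNIST("./data", train=False, transform=ToTensor())),
        batch_size=32,
    ),
}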
def main(args):

    set_global_seed(args.seed)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    datasets = {
        "train": Wrp(
            CIFAR100(root=".", train=True, download=True, transform=transform_train)
        ),
        "valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
    }

    loaders = {
        k: DataLoader(v,
                      batch_size=args.batch_size,
                      shuffle=k == "train",
                      num_workers=2)
        for k, v in datasets.items()
    }
    model = NAME2MODEL[args.model](num_classes=100)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     [150, 180, 210],
                                                     gamma=0.1)

    runner = SupervisedRunner()
    runner.train(model=model,
                 loaders=loaders,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 valid_metric="accuracy",
                 minimize_valid_metric=False,
                 logdir=args.logdir,
                 callbacks=[
                     CriterionCallback(input_key="logits",
                                       target_key="targets",
                                       metric_key="loss"),
                     AccuracyCallback(input_key="logits",
                                      target_key="targets"),
                     OptimizerCallback(metric_key="loss"),
                     SchedulerCallback(),
                 ],
                 valid_loader="valid",
                 num_epochs=args.num_epochs,
                 criterion=torch.nn.CrossEntropyLoss(),
                 seed=args.seed)
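
`main(args)` reads `args.model`, `args.batch_size`, `args.lr`, `args.num_epochs`, `args.logdir`, and `args.seed`; a hypothetical entry point wiring these up (defaults are illustrative, not taken from the original script):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)           # key into NAME2MODEL
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--num-epochs", type=int, default=240)
    parser.add_argument("--logdir", type=str, default="./logs")
    parser.add_argument("--seed", type=int, default=42)
    main(parser.parse_args())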