Example #1
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        model = {"teacher": teacher, "student": student}
        criterion = {"cls": nn.CrossEntropyLoss(), "kl": nn.KLDivLoss(reduction="batchmean")}
        optimizer = optim.Adam(student.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32
            ),
        }

        runner = DistilRunner()
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=False,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="t_logits", target_key="targets", num_classes=2, prefix="teacher_"
                ),
                dl.AccuracyCallback(
                    input_key="s_logits", target_key="targets", num_classes=2, prefix="student_"
                ),
                dl.CriterionCallback(
                    input_key="s_logits",
                    target_key="targets",
                    metric_key="cls_loss",
                    criterion_key="cls",
                ),
                dl.CriterionCallback(
                    input_key="s_logprobs",
                    target_key="t_probs",
                    metric_key="kl_div_loss",
                    criterion_key="kl",
                ),
                dl.MetricAggregationCallback(
                    metric_key="loss", metrics=["kl_div_loss", "cls_loss"], mode="mean"
                ),
                dl.OptimizerCallback(metric_key="loss", model_key="student"),
                dl.CheckpointCallback(
                    logdir=logdir,
                    loader_key="valid",
                    metric_key="loss",
                    minimize=True,
                    save_n_best=3,
                ),
            ],
        )
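The excerpts on this page omit their import headers, and this first example also relies on a DistilRunner class that is only defined further down (see Example #7). A plausible import block that the MNIST distillation examples appear to assume is sketched below; the exact imports depend on the source repository.

# Sketch of the imports assumed by the MNIST distillation examples on this page.
import os
from tempfile import TemporaryDirectory

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from catalyst import dl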
Example #2
def test_aggregation_2():
    """
    Aggregation with custom function
    """
    loaders, model, criterion, optimizer = prepare_experiment()
    runner = dl.SupervisedRunner()

    def aggregation_function(metrics, runner):
        epoch = runner.stage_epoch_step
        focal_weight = 3 / 2 - epoch / 2
        bce_weight = epoch / 2 - 1 / 2
        return focal_weight * metrics["loss_focal"] + bce_weight * metrics["loss_bce"]

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs/aggregation_2/",
        num_epochs=3,
        callbacks=[
            dl.CriterionCallback(
                input_key="logits",
                target_key="targets",
                metric_key="loss_bce",
                criterion_key="bce",
            ),
            dl.CriterionCallback(
                input_key="logits",
                target_key="targets",
                metric_key="loss_focal",
                criterion_key="focal",
            ),
            # loss aggregation
            dl.MetricAggregationCallback(metric_key="loss", mode=aggregation_function),
        ],
    )
    for loader in ["train", "valid"]:
        metrics = runner.epoch_metrics[loader]
        loss_1 = metrics["loss_bce"]
        loss_2 = metrics["loss"]
        assert np.abs(loss_1 - loss_2) < 1e-5
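For reference, the schedule implied by aggregation_function above moves linearly from pure focal loss to pure BCE over the three training epochs, which is why the final-epoch assertion expects loss to equal loss_bce:

# weight schedule implied by aggregation_function (epoch = runner.stage_epoch_step)
for epoch in (1, 2, 3):
    focal_weight = 3 / 2 - epoch / 2
    bce_weight = epoch / 2 - 1 / 2
    print(epoch, focal_weight, bce_weight)  # 1: (1.0, 0.0), 2: (0.5, 0.5), 3: (0.0, 1.0)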
Example #3
def test_aggregation_1():
    """
    Aggregation as weighted_sum
    """
    loaders, model, criterion, optimizer = prepare_experiment()
    runner = dl.SupervisedRunner()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs/aggregation_1/",
        num_epochs=3,
        callbacks=[
            dl.CriterionCallback(
                input_key="logits",
                target_key="targets",
                metric_key="loss_bce",
                criterion_key="bce",
            ),
            dl.CriterionCallback(
                input_key="logits",
                target_key="targets",
                metric_key="loss_focal",
                criterion_key="focal",
            ),
            # loss aggregation
            dl.MetricAggregationCallback(
                metric_key="loss",
                metrics={
                    "loss_focal": 0.6,
                    "loss_bce": 0.4
                },
                mode="weighted_sum",
            ),
        ],
    )
    for loader in ["train", "valid"]:
        metrics = runner.epoch_metrics[loader]
        loss_1 = metrics["loss_bce"] * 0.4 + metrics["loss_focal"] * 0.6
        loss_2 = metrics["loss"]
        assert np.abs(loss_1 - loss_2) < 1e-5
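Both tests above call a prepare_experiment() helper that is not reproduced on this page. The stand-in below is purely hypothetical (random binary data, a one-layer model, an inline focal loss, and an illustrative learning rate), but it returns the four objects the tests expect and uses the same "bce"/"focal" criterion keys; the real helper lives elsewhere in the test suite.

# Hypothetical stand-in for the prepare_experiment() helper used by the tests above.
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


class BinaryFocalLoss(nn.Module):
    """Minimal binary focal loss (no alpha weighting)."""

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # probability assigned to the true class
        return ((1 - p_t) ** self.gamma * bce).mean()


def prepare_experiment(num_samples: int = 10_000, num_features: int = 10):
    X = torch.rand(num_samples, num_features)
    y = (torch.rand(num_samples, 1) > 0.5).float()
    loader = DataLoader(TensorDataset(X, y), batch_size=32)
    loaders = {"train": loader, "valid": loader}
    model = nn.Linear(num_features, 1)
    criterion = {"bce": nn.BCEWithLogitsLoss(), "focal": BinaryFocalLoss()}
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    return loaders, model, criterion, optimizer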
Example #4
def main(args):
    wandb.init(project="teacher-pruning", config=vars(args))
    set_global_seed(42)
    # dataloader initialization
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = Wrp(
        datasets.CIFAR10(root=os.getcwd(),
                         train=True,
                         transform=transform_train,
                         download=True))
    valid_dataset = Wrp(
        datasets.CIFAR10(root=os.getcwd(),
                         train=False,
                         transform=transform_test))
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=2)
    valid_dataloader = DataLoader(dataset=valid_dataset,
                                  batch_size=128,
                                  num_workers=2)
    loaders = {
        "train": train_dataloader,
        "valid": valid_dataloader,
    }
    # model initialization
    model = PreActResNet18()
    model.fc = nn.Linear(512, 10)
    if args.teacher_model is not None:
        is_kd = True
        teacher_model = NAME2MODEL[args.teacher_model]()
        load_model_from_path(model=teacher_model, path=args.teacher_path)
        model = {
            "student": model,
            "teacher": teacher_model,
        }
        output_hiddens = args.beta is None
        is_kd_on_hiddens = output_hiddens
        runner = KDRunner(device=args.device, output_hiddens=output_hiddens)
        parameters = model["student"].parameters()
    else:
        is_kd = False
        runner = dl.SupervisedRunner(device=args.device)
        parameters = model.parameters()
    # optimizer
    optimizer_cls = NAME2OPTIM[args.optimizer]
    optimizer_kwargs = {"params": parameters, "lr": args.lr}
    if args.optimizer == "sgd":
        optimizer_kwargs["momentum"] = args.momentum
    else:
        optimizer_kwargs["betas"] = (args.beta1, args.beta2)
    optimizer = optimizer_cls(**optimizer_kwargs)
    scheduler = MultiStepLR(optimizer, milestones=[80, 120], gamma=args.gamma)
    logdir = f"logs/{wandb.run.name}"
    # callbacks
    callbacks = [dl.AccuracyCallback(num_classes=10), WandbCallback()]
    if is_kd:
        metrics = {}
        callbacks.append(dl.CriterionCallback(output_key="cls_loss"))
        callbacks.append(DiffOutputCallback())
        coefs = get_loss_coefs(args.alpha, args.beta)
        metrics["cls_loss"] = coefs[0]
        metrics["diff_output_loss"] = coefs[1]
        if is_kd_on_hiddens:
            callbacks.append(DiffHiddenCallback())
            metrics["diff_hidden_loss"] = coefs[2]

        aggregator_callback = dl.MetricAggregationCallback(prefix="loss",
                                                           metrics=metrics,
                                                           mode="weighted_sum")
        wrapped_agg_callback = dl.ControlFlowCallback(aggregator_callback,
                                                      loaders=["train"])
        callbacks.append(wrapped_agg_callback)

    runner.train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=nn.CrossEntropyLoss(),
        loaders=loaders,
        callbacks=callbacks,
        num_epochs=args.epoch,
        logdir=logdir,
        verbose=True,
    )
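main() above reads a number of fields from args (optimizer choice, Adam betas, distillation coefficients, teacher checkpoint, and so on). A hypothetical argparse front end matching those field names is sketched below; the defaults are illustrative only and do not come from the original repository.

# Hypothetical CLI wrapper for main(); field names mirror the args.* attributes used above.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="teacher-pruning / distillation experiment")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--epoch", type=int, default=150)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="sgd")
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--momentum", type=float, default=0.9)      # used when --optimizer sgd
    parser.add_argument("--beta1", type=float, default=0.9)          # used otherwise
    parser.add_argument("--beta2", type=float, default=0.999)
    parser.add_argument("--gamma", type=float, default=0.1)          # MultiStepLR decay factor
    parser.add_argument("--teacher_model", type=str, default=None)   # key into NAME2MODEL
    parser.add_argument("--teacher_path", type=str, default=None)
    parser.add_argument("--alpha", type=float, default=0.5)          # KD loss coefficients
    parser.add_argument("--beta", type=float, default=None)
    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args())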
Example #5
def train_experiment(device):
    with TemporaryDirectory() as logdir:
        # sample data
        num_samples, num_features, num_classes1, num_classes2 = int(1e4), int(1e1), 4, 10
        X = torch.rand(num_samples, num_features)
        y1 = (torch.rand(num_samples) * num_classes1).to(torch.int64)
        y2 = (torch.rand(num_samples) * num_classes2).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y1, y2)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        class CustomModule(nn.Module):
            def __init__(self, in_features: int, out_features1: int,
                         out_features2: int):
                super().__init__()
                self.shared = nn.Linear(in_features, 128)
                self.head1 = nn.Linear(128, out_features1)
                self.head2 = nn.Linear(128, out_features2)

            def forward(self, x):
                x = self.shared(x)
                y1 = self.head1(x)
                y2 = self.head2(x)
                return y1, y2

        # model, criterion, optimizer, scheduler
        model = CustomModule(num_features, num_classes1, num_classes2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [2])

        class CustomRunner(dl.Runner):
            def handle_batch(self, batch):
                x, y1, y2 = batch
                y1_hat, y2_hat = self.model(x)
                self.batch = {
                    "features": x,
                    "logits1": y1_hat,
                    "logits2": y2_hat,
                    "targets1": y1,
                    "targets2": y2,
                }

        # model training
        runner = CustomRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=[
                dl.CriterionCallback(metric_key="loss1",
                                     input_key="logits1",
                                     target_key="targets1"),
                dl.CriterionCallback(metric_key="loss2",
                                     input_key="logits2",
                                     target_key="targets2"),
                dl.MetricAggregationCallback(prefix="loss",
                                             metrics=["loss1", "loss2"],
                                             mode="mean"),
                dl.OptimizerCallback(metric_key="loss"),
                dl.SchedulerCallback(),
                dl.AccuracyCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_",
                ),
                dl.AccuracyCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_",
                ),
                dl.ConfusionMatrixCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_cm",
                ),
                # catalyst[ml] required
                dl.ConfusionMatrixCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_cm",
                ),
                # catalyst[ml] required
                dl.CheckpointCallback(
                    "./logs/one",
                    loader_key="valid",
                    metric_key="one_accuracy",
                    minimize=False,
                    save_n_best=1,
                ),
                dl.CheckpointCallback(
                    "./logs/two",
                    loader_key="valid",
                    metric_key="two_accuracy03",
                    minimize=False,
                    save_n_best=3,
                ),
            ],
            loggers={
                "console": dl.ConsoleLogger(),
                "tb": dl.TensorboardLogger("./logs/tb")
            },
        )
Example #6
def train_experiment(engine=None):
    with TemporaryDirectory() as logdir:
        # sample data
        num_samples, num_features, num_classes1, num_classes2 = int(1e4), int(1e1), 4, 10
        X = torch.rand(num_samples, num_features)
        y1 = (torch.rand(num_samples) * num_classes1).to(torch.int64)
        y2 = (torch.rand(num_samples) * num_classes2).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y1, y2)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        # (CustomModule and CustomRunner below are the classes defined in the previous example)
        model = CustomModule(num_features, num_classes1, num_classes2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [2])

        callbacks = [
            dl.CriterionCallback(metric_key="loss1",
                                 input_key="logits1",
                                 target_key="targets1"),
            dl.CriterionCallback(metric_key="loss2",
                                 input_key="logits2",
                                 target_key="targets2"),
            dl.MetricAggregationCallback(metric_key="loss",
                                         metrics=["loss1", "loss2"],
                                         mode="mean"),
            dl.BackwardCallback(metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            dl.AccuracyCallback(
                input_key="logits1",
                target_key="targets1",
                num_classes=num_classes1,
                prefix="one_",
            ),
            dl.AccuracyCallback(
                input_key="logits2",
                target_key="targets2",
                num_classes=num_classes2,
                prefix="two_",
            ),
            dl.CheckpointCallback(
                "./logs/one",
                loader_key="valid",
                metric_key="one_accuracy01",
                minimize=False,
                topk=1,
            ),
            dl.CheckpointCallback(
                "./logs/two",
                loader_key="valid",
                metric_key="two_accuracy03",
                minimize=False,
                topk=3,
            ),
        ]
        if SETTINGS.ml_required:
            # catalyst[ml] required
            callbacks.append(
                dl.ConfusionMatrixCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_cm",
                ))
            # catalyst[ml] required
            callbacks.append(
                dl.ConfusionMatrixCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_cm",
                ))

        # model training
        runner = CustomRunner()
        runner.train(
            engine=engine,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=callbacks,
            loggers={
                "console": dl.ConsoleLogger(),
                "tb": dl.TensorboardLogger("./logs/tb"),
            },
        )
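This variant gates the ConfusionMatrixCallback instances behind Catalyst's optional catalyst[ml] extras. Assuming a recent Catalyst release, the feature flag it checks is importable as shown below (not part of the original excerpt):

# SETTINGS exposes feature flags for Catalyst's optional extras;
# ml_required reflects whether the catalyst[ml] dependencies are available.
from catalyst.settings import SETTINGS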
Example #7
def train_experiment(device):
    with TemporaryDirectory() as logdir:
        teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = {
            "cls": nn.CrossEntropyLoss(),
            "kl": nn.KLDivLoss(reduction="batchmean")
        }
        optimizer = optim.Adam(student.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()),
                batch_size=32,
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
                batch_size=32,
            ),
        }

        class DistilRunner(dl.Runner):
            def handle_batch(self, batch):
                x, y = batch

                teacher.eval()  # let's manually set teacher model to eval mode
                with torch.no_grad():
                    t_logits = self.model["teacher"](x)

                s_logits = self.model["student"](x)
                self.batch = {
                    "t_logits": t_logits,
                    "s_logits": s_logits,
                    "targets": y,
                    "s_logprobs": F.log_softmax(s_logits, dim=-1),
                    "t_probs": F.softmax(t_logits, dim=-1),
                }

        runner = DistilRunner()
        # model training
        runner.train(
            engine=dl.DeviceEngine(device),
            model={
                "teacher": teacher,
                "student": student
            },
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=True,
            callbacks=[
                dl.AccuracyCallback(input_key="t_logits",
                                    target_key="targets",
                                    num_classes=10,
                                    prefix="teacher_"),
                dl.AccuracyCallback(input_key="s_logits",
                                    target_key="targets",
                                    num_classes=10,
                                    prefix="student_"),
                dl.CriterionCallback(
                    input_key="s_logits",
                    target_key="targets",
                    metric_key="cls_loss",
                    criterion_key="cls",
                ),
                dl.CriterionCallback(
                    input_key="s_logprobs",
                    target_key="t_probs",
                    metric_key="kl_div_loss",
                    criterion_key="kl",
                ),
                dl.MetricAggregationCallback(
                    prefix="loss",
                    metrics=["kl_div_loss", "cls_loss"],
                    mode="mean"),
                dl.OptimizerCallback(metric_key="loss", model_key="student"),
                dl.CheckpointCallback(
                    logdir=logdir,
                    loader_key="valid",
                    metric_key="loss",
                    minimize=True,
                    save_n_best=3,
                ),
            ],
        )
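For reference, the two criterion callbacks plus the mean aggregation in the distillation examples compute, per batch, something equivalent to the sketch below (distillation_loss is a hypothetical helper name; these excerpts do not use temperature scaling):

# manual equivalent of the aggregated distillation loss used in Examples #1 and #7
import torch.nn.functional as F


def distillation_loss(s_logits, t_logits, targets):
    cls_loss = F.cross_entropy(s_logits, targets)
    kl_div_loss = F.kl_div(
        F.log_softmax(s_logits, dim=-1),  # student log-probabilities ("s_logprobs")
        F.softmax(t_logits, dim=-1),      # teacher probabilities ("t_probs")
        reduction="batchmean",
    )
    return (cls_loss + kl_div_loss) / 2   # MetricAggregationCallback with mode="mean"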