Exemple #1
0
def main():
    """Train a small DCGAN-style generator/discriminator pair on MNIST.

    Both networks get their own Adam optimizer (lr=3e-4, betas=(0.5, 0.999))
    and their own ``OptimizerCallback`` keyed by ``loss_generator`` /
    ``loss_discriminator``. Training is delegated to ``CustomRunner`` (defined
    elsewhere) for 20 epochs with logs under ``./logs_gan``; ``check=True`` is
    forwarded to ``runner.train``.
    """

    def leaky():
        # In-place LeakyReLU with negative slope 0.2; a new module per call.
        return nn.LeakyReLU(0.2, inplace=True)

    def upsample_x2():
        # kernel 4 / stride 2 / padding 1 doubles the spatial size.
        return nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1)

    def to_feature_map(flat):
        # (N, 128 * 7 * 7) -> (N, 128, 7, 7)
        return flat.view(flat.size(0), 128, 7, 7)

    # Generator: 128-dim latent vector -> 128 * 7 * 7 dense coefficients ->
    # (128, 7, 7) map -> 14x14 -> 28x28 -> one channel squashed into (0, 1).
    generator = nn.Sequential(
        nn.Linear(128, 128 * 7 * 7), leaky(),
        Lambda(to_feature_map),
        upsample_x2(), leaky(),
        upsample_x2(), leaky(),
        nn.Conv2d(128, 1, (7, 7), padding=3), nn.Sigmoid(),
    )
    # Discriminator: two stride-2 convs (28 -> 14 -> 7), global max pooling,
    # then a single real/fake logit.
    discriminator = nn.Sequential(
        nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1), leaky(),
        nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1), leaky(),
        GlobalMaxPool2d(), Flatten(),
        nn.Linear(128, 1),
    )

    model = {"generator": generator, "discriminator": discriminator}
    # Dict order keeps the generator's optimizer first, matching `model`.
    optimizer = {
        name: torch.optim.Adam(net.parameters(), lr=0.0003, betas=(0.5, 0.999))
        for name, net in model.items()
    }
    train_set = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
    loaders = {"train": DataLoader(train_set, batch_size=32)}

    runner = CustomRunner()
    callbacks = [
        dl.OptimizerCallback(optimizer_key=name, metric_key=f"loss_{name}")
        for name in ("generator", "discriminator")
    ]
    runner.train(
        model=model,
        optimizer=optimizer,
        loaders=loaders,
        callbacks=callbacks,
        main_metric="loss_generator",
        num_epochs=20,
        verbose=True,
        logdir="./logs_gan",
        check=True,
    )
Exemple #2
0
 def get_callbacks(self, stage: str):
     """Return the named callbacks used for training; ``stage`` is not consulted.

     All metric callbacks read ``logits`` against ``targets``. The optimizer
     callback clips gradients with ``clip_grad_norm_`` at ``max_norm=1.0``.
     The checkpoint is configured on the ``valid`` loader's ``loss`` (minimized,
     ``save_n_best=3``). A confusion-matrix callback is appended only when
     ``SETTINGS.ml_required`` is truthy.
     """
     logit_keys = dict(input_key="logits", target_key="targets")
     callbacks = {
         "criterion": dl.CriterionCallback(metric_key="loss", **logit_keys),
         "optimizer": dl.OptimizerCallback(
             metric_key="loss",
             grad_clip_fn=nn.utils.clip_grad_norm_,
             grad_clip_params={"max_norm": 1.0},
         ),
         # A SchedulerCallback(loader_key="valid", metric_key="loss") is
         # intentionally not registered here.
         "accuracy": dl.AccuracyCallback(topk_args=(1, 3, 5), **logit_keys),
         "classification": dl.PrecisionRecallF1SupportCallback(
             num_classes=10, **logit_keys
         ),
         "checkpoint": dl.CheckpointCallback(
             self._logdir,
             loader_key="valid",
             metric_key="loss",
             minimize=True,
             save_n_best=3,
         ),
     }
     if SETTINGS.ml_required:
         callbacks["confusion_matrix"] = dl.ConfusionMatrixCallback(
             num_classes=10, **logit_keys
         )
     return callbacks
Exemple #3
0
 def get_callbacks(self, stage: str):
     """Return optimizer and checkpoint callbacks; ``stage`` is unused.

     Both read the ``loss`` metric. The checkpoint watches the ``valid``
     loader with ``minimize=True`` and ``save_n_best=3``.
     """
     optimizer_cb = dl.OptimizerCallback(metric_key="loss")
     checkpoint_cb = dl.CheckpointCallback(
         self._logdir,
         loader_key="valid",
         metric_key="loss",
         minimize=True,
         save_n_best=3,
     )
     return {"optimizer": optimizer_cb, "checkpoint": checkpoint_cb}
Exemple #4
0
 def get_callbacks(self, stage: str):
     """Return criterion/optimizer/scheduler/accuracy/checkpoint callbacks.

     ``stage`` is unused. Metric callbacks read ``logits`` against ``labels``.
     The scheduler callback is created with ``mode="batch"``. The checkpoint
     keeps one state ranked by ``accuracy`` on ``valid`` (higher is better).
     """
     io_keys = {"input_key": "logits", "target_key": "labels"}
     criterion_cb = dl.CriterionCallback(metric_key="loss", **io_keys)
     optimizer_cb = dl.OptimizerCallback(metric_key="loss")
     scheduler_cb = dl.SchedulerCallback(
         loader_key="valid", metric_key="loss", mode="batch"
     )
     accuracy_cb = dl.AccuracyCallback(topk_args=(1,), **io_keys)
     # NOTE(review): other snippets checkpoint on "accuracy01"; confirm that
     # AccuracyCallback also emits a bare "accuracy" key in this version.
     checkpoint_cb = dl.CheckpointCallback(
         self._logdir,
         loader_key="valid",
         metric_key="accuracy",
         minimize=False,
         save_n_best=1,
     )
     return {
         "criterion": criterion_cb,
         "optimizer": optimizer_cb,
         "scheduler": scheduler_cb,
         "accuracy": accuracy_cb,
         "checkpoint": checkpoint_cb,
     }
Exemple #5
0
 def get_callbacks(self, stage: str):
     """Return the training/metric/checkpoint callbacks; ``stage`` is unused.

     Every metric callback compares ``logits`` with ``targets``; the
     classification and confusion-matrix callbacks assume 10 classes. The
     checkpoint is configured on the ``valid`` loader's ``loss`` (minimized,
     ``save_n_best=3``).
     """
     keys = dict(input_key="logits", target_key="targets")
     # No SchedulerCallback is registered for this experiment.
     return dict(
         criterion=dl.CriterionCallback(metric_key="loss", **keys),
         optimizer=dl.OptimizerCallback(metric_key="loss"),
         accuracy=dl.AccuracyCallback(topk_args=(1, 3, 5), **keys),
         classification=dl.PrecisionRecallF1SupportCallback(num_classes=10, **keys),
         confusion_matrix=dl.ConfusionMatrixCallback(num_classes=10, **keys),
         checkpoint=dl.CheckpointCallback(
             self._logdir,
             loader_key="valid",
             metric_key="loss",
             minimize=True,
             save_n_best=3,
         ),
     )
Exemple #6
0
 def get_callbacks(self):
     """Return the callbacks for a full train loop, keyed by name.

     Loss ``loss`` is computed from ``logits``/``targets``, backpropagated and
     stepped; the scheduler reads the ``valid`` loss. Accuracy is tracked at
     top-1/3/5 and the checkpoint keeps ``topk=1`` state ranked by
     ``accuracy01`` on ``valid`` (higher is better). A tqdm progress callback
     is included.
     """
     targets_io = {"input_key": "logits", "target_key": "targets"}
     named = [
         ("criterion", dl.CriterionCallback(metric_key="loss", **targets_io)),
         ("backward", dl.BackwardCallback(metric_key="loss")),
         ("optimizer", dl.OptimizerCallback(metric_key="loss")),
         ("scheduler", dl.SchedulerCallback(loader_key="valid", metric_key="loss")),
         ("accuracy", dl.AccuracyCallback(topk=(1, 3, 5), **targets_io)),
         (
             "checkpoint",
             dl.CheckpointCallback(
                 self._logdir,
                 loader_key="valid",
                 metric_key="accuracy01",
                 minimize=False,
                 topk=1,
             ),
         ),
         ("tqdm", dl.TqdmCallback()),
     ]
     return dict(named)
Exemple #7
0
def train_experiment(device, engine=None):
    """Distil a teacher MNIST classifier into a student for one epoch.

    Teacher and student are each a flatten + single linear layer over 28x28
    images with one output per MNIST class. Only the student is optimized:
    the optimizer holds the student's parameters and the OptimizerCallback
    targets ``model_key="student"``. The training loss is the mean of the
    student's cross-entropy (``cls_loss``) and the KL divergence between the
    student's log-probabilities and the teacher's probabilities
    (``kl_div_loss``).

    ``DistilRunner`` (defined elsewhere) is expected to provide the
    ``t_logits``, ``s_logits``, ``s_logprobs`` and ``t_probs`` batch keys.

    Args:
        device: device for ``dl.DeviceEngine`` when ``engine`` is not given.
        engine: optional Catalyst engine that takes precedence over ``device``.
    """
    # MNIST has 10 digit classes. The accuracy callbacks previously said 2,
    # which silently restricted their default top-k metrics to top-1.
    num_classes = 10
    with TemporaryDirectory() as logdir:
        teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, num_classes))
        student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, num_classes))
        model = {"teacher": teacher, "student": student}
        criterion = {"cls": nn.CrossEntropyLoss(), "kl": nn.KLDivLoss(reduction="batchmean")}
        optimizer = optim.Adam(student.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32
            ),
        }

        runner = DistilRunner()
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=False,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="t_logits",
                    target_key="targets",
                    num_classes=num_classes,
                    prefix="teacher_",
                ),
                dl.AccuracyCallback(
                    input_key="s_logits",
                    target_key="targets",
                    num_classes=num_classes,
                    prefix="student_",
                ),
                # Hard-label loss of the student.
                dl.CriterionCallback(
                    input_key="s_logits",
                    target_key="targets",
                    metric_key="cls_loss",
                    criterion_key="cls",
                ),
                # KLDivLoss expects log-probabilities as input and
                # probabilities as target.
                dl.CriterionCallback(
                    input_key="s_logprobs",
                    target_key="t_probs",
                    metric_key="kl_div_loss",
                    criterion_key="kl",
                ),
                dl.MetricAggregationCallback(
                    metric_key="loss", metrics=["kl_div_loss", "cls_loss"], mode="mean"
                ),
                dl.OptimizerCallback(metric_key="loss", model_key="student"),
                dl.CheckpointCallback(
                    logdir=logdir,
                    loader_key="valid",
                    metric_key="loss",
                    minimize=True,
                    save_n_best=3,
                ),
            ],
        )
 def get_callbacks(self, stage: str):
     """Return the named callbacks for training; ``stage`` is not consulted.

     Two batch transforms run at ``on_batch_end``. The first stores
     ``scores = softmax(logits, dim=1)`` and the second stores
     ``labels = argmax(scores, dim=1)``.

     Loss and metric callbacks compare ``logits`` with ``targets``. The
     optimizer clips gradients with ``clip_grad_norm_`` at ``max_norm=1.0``.
     The checkpoint is configured on the ``valid`` loss (minimized,
     ``save_n_best=3``).

     When ``SETTINGS.ml_required`` is truthy, a confusion matrix and a sklearn
     macro F1 on ``labels`` vs ``targets`` (reported as ``sk_f1``) are added.
     """
     logit_keys = {"input_key": "logits", "target_key": "targets"}
     callbacks = {}
     callbacks["scores"] = dl.BatchTransformCallback(
         input_key="logits",
         output_key="scores",
         transform=partial(torch.softmax, dim=1),
         scope="on_batch_end",
     )
     callbacks["labels"] = dl.BatchTransformCallback(
         input_key="scores",
         output_key="labels",
         transform=partial(torch.argmax, dim=1),
         scope="on_batch_end",
     )
     callbacks["criterion"] = dl.CriterionCallback(metric_key="loss", **logit_keys)
     callbacks["optimizer"] = dl.OptimizerCallback(
         metric_key="loss",
         grad_clip_fn=nn.utils.clip_grad_norm_,
         grad_clip_params={"max_norm": 1.0},
     )
     # No SchedulerCallback is registered here.
     callbacks["accuracy"] = dl.AccuracyCallback(topk_args=(1, 3, 5), **logit_keys)
     callbacks["classification"] = dl.PrecisionRecallF1SupportCallback(
         num_classes=10, **logit_keys
     )
     callbacks["checkpoint"] = dl.CheckpointCallback(
         self._logdir,
         loader_key="valid",
         metric_key="loss",
         minimize=True,
         save_n_best=3,
     )
     if not SETTINGS.ml_required:
         return callbacks

     callbacks["confusion_matrix"] = dl.ConfusionMatrixCallback(
         num_classes=10, **logit_keys
     )
     callbacks["f1_score"] = dl.SklearnBatchCallback(
         keys={"y_pred": "labels", "y_true": "targets"},
         metric_fn="f1_score",
         metric_key="sk_f1",
         average="macro",
         zero_division=1,
     )
     return callbacks
Exemple #9
0
def train_experiment(device):
    """Run a one-epoch multi-label recommender smoke test on random data.

    A linear model maps 10 random features to 10 item logits; targets are
    random 0/1 labels. The same loader is used for ``train`` and ``valid``.
    Loss uses raw ``logits``; AUC and the top-1/3/5 ranking metrics use the
    sigmoid ``scores``. The checkpoint tracks ``map01`` (higher is better)
    inside a temporary directory.

    Args:
        device: device passed to ``dl.DeviceEngine``.
    """
    with TemporaryDirectory() as logdir:
        n_users, n_features, n_items = int(1e4), int(1e1), 10
        features = torch.rand(n_users, n_features)
        targets = (torch.rand(n_users, n_items) > 0.5).to(torch.float32)

        loader = DataLoader(TensorDataset(features, targets), batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        model = torch.nn.Linear(n_features, n_items)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        class CustomRunner(dl.Runner):
            def handle_batch(self, batch):
                # Publish inputs, logits, sigmoid scores and targets under the
                # keys the callbacks below read.
                inputs, labels = batch
                logits = self.model(inputs)
                self.batch = dict(
                    features=inputs,
                    logits=logits,
                    scores=torch.sigmoid(logits),
                    targets=labels,
                )

        ranking_metrics = (dl.HitrateCallback, dl.MRRCallback, dl.MAPCallback, dl.NDCGCallback)
        callbacks = [
            dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            dl.AUCCallback(input_key="scores", target_key="targets"),
            *[
                metric_cls(input_key="scores", target_key="targets", topk_args=(1, 3, 5))
                for metric_cls in ranking_metrics
            ],
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            dl.CheckpointCallback(
                logdir=logdir, loader_key="valid", metric_key="map01", minimize=False
            ),
        ]

        runner = CustomRunner()
        runner.train(
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=callbacks,
        )
Exemple #10
0
 def get_callbacks(self, stage: str) -> Dict[str, dl.Callback]:
     """Return training callbacks plus two ``layer1`` requires-grad checks.

     ``stage`` is unused. The checkpoint is configured on the ``valid`` loss
     (minimized, ``save_n_best=3``). ``CheckRequiresGrad`` (defined elsewhere)
     receives ``("layer1", "train_freezed", False)`` and
     ``("layer1", "train_unfreezed", True)``. Presumably it asserts that
     ``layer1`` is frozen in one stage and trainable in the other; confirm
     against its definition.
     """
     callbacks = {}
     callbacks["criterion"] = dl.CriterionCallback(
         input_key="logits", target_key="targets", metric_key="loss"
     )
     callbacks["optimizer"] = dl.OptimizerCallback(metric_key="loss")
     # No SchedulerCallback is registered here.
     callbacks["checkpoint"] = dl.CheckpointCallback(
         self._logdir,
         loader_key="valid",
         metric_key="loss",
         minimize=True,
         save_n_best=3,
     )
     callbacks["check_freezed"] = CheckRequiresGrad("layer1", "train_freezed", False)
     callbacks["check_unfreezed"] = CheckRequiresGrad("layer1", "train_unfreezed", True)
     return callbacks
def test():
    """Test Notebook API: run MelGAN training once in ``check`` mode.

    Mel data comes from ``data/test``. Generator (built with argument 80) and
    discriminator each get their own Adam optimizer (``opt_g`` / ``opt_d``),
    stepped on ``generator_loss`` / ``discriminator_loss`` respectively.
    """
    dataloader = torch.utils.data.DataLoader(MelFromDisk(path="data/test"))
    loaders = OrderedDict(train=dataloader)

    generator, discriminator = Generator(80), Discriminator()
    model = torch.nn.ModuleDict(dict(generator=generator, discriminator=discriminator))
    optimizer = dict(
        opt_g=torch.optim.Adam(generator.parameters()),
        opt_d=torch.optim.Adam(discriminator.parameters()),
    )
    callbacks = dict(
        loss_g=GeneratorLossCallback(),
        loss_d=DiscriminatorLossCallback(),
        o_g=dl.OptimizerCallback(metric_key="generator_loss", optimizer_key="opt_g"),
        o_d=dl.OptimizerCallback(metric_key="discriminator_loss", optimizer_key="opt_d"),
    )

    runner = MelGANRunner()
    runner.train(
        model=model,
        loaders=loaders,
        optimizer=optimizer,
        callbacks=callbacks,
        check=True,
        main_metric="discriminator_loss",
    )
Exemple #12
0
def train_experiment(device, engine=None):
    """Run a one-epoch multi-label recommender smoke test on random data.

    ``CustomRunner`` is defined outside this function. Presumably it
    publishes ``logits`` and ``scores`` batch keys; confirm. AUC on raw
    ``logits`` is added only when ``engine`` is ``None`` or not one of the
    AMP engines.

    Args:
        device: device for ``dl.DeviceEngine`` when ``engine`` is not given.
        engine: optional Catalyst engine that takes precedence over ``device``.
    """
    with TemporaryDirectory() as logdir:
        n_users, n_features, n_items = int(1e4), int(1e1), 10
        features = torch.rand(n_users, n_features)
        targets = (torch.rand(n_users, n_items) > 0.5).to(torch.float32)

        loader = DataLoader(TensorDataset(features, targets), batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        model = torch.nn.Linear(n_features, n_items)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        score_keys = dict(input_key="scores", target_key="targets")
        callbacks = [
            dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            dl.AUCCallback(**score_keys),
            dl.HitrateCallback(topk_args=(1, 3, 5), **score_keys),
            dl.MRRCallback(topk_args=(1, 3, 5), **score_keys),
            dl.MAPCallback(topk_args=(1, 3, 5), **score_keys),
            dl.NDCGCallback(topk_args=(1, 3, 5), **score_keys),
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            dl.CheckpointCallback(
                logdir=logdir, loader_key="valid", metric_key="map01", minimize=False
            ),
        ]
        # NOTE(review): this second AUCCallback uses the same default metric
        # names as the one on "scores"; confirm they do not overwrite each other.
        amp_free = engine is None or not isinstance(
            engine, (dl.AMPEngine, dl.DataParallelAMPEngine, dl.DistributedDataParallelAMPEngine)
        )
        if amp_free:
            callbacks.append(dl.AUCCallback(input_key="logits", target_key="targets"))

        runner = CustomRunner()
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=callbacks,
        )
Exemple #13
0
 def get_callbacks(self):
     """Return backward, optimizer and checkpoint callbacks keyed by name.

     All three read the ``loss`` metric. The checkpoint watches the ``valid``
     loader with ``minimize=True`` and ``topk=3``.
     """
     loss_key = "loss"
     return dict(
         backward=dl.BackwardCallback(metric_key=loss_key),
         optimizer=dl.OptimizerCallback(metric_key=loss_key),
         checkpoint=dl.CheckpointCallback(
             self._logdir,
             loader_key="valid",
             metric_key=loss_key,
             minimize=True,
             topk=3,
         ),
     )
Exemple #14
0
 def get_callbacks(self, stage: str):
     """Return training callbacks plus a best-checkpoint reload check.

     ``stage`` is unused. The checkpoint is configured on the ``valid`` loss
     (minimized, ``save_n_best=3``) with ``load_on_stage_start="best"``.
     ``CheckModelStateLoadAfterStages`` (defined elsewhere) receives
     ``"second"``, the log dir and ``"best.pth"``. Presumably it verifies that
     the best weights were reloaded at the ``second`` stage; confirm.
     """
     ranking = dict(loader_key="valid", metric_key="loss", minimize=True)
     callbacks = {
         "criterion": dl.CriterionCallback(
             input_key="logits", target_key="targets", metric_key="loss"
         ),
         "optimizer": dl.OptimizerCallback(metric_key="loss"),
         "checkpoint": dl.CheckpointCallback(
             self._logdir, save_n_best=3, load_on_stage_start="best", **ranking
         ),
         "test_model_load": CheckModelStateLoadAfterStages(
             "second", self._logdir, "best.pth"
         ),
     }
     return callbacks
def test_is_running():
    """Test if perplexity is running normal.

    Loads the ``distilbert-base-uncased`` tokenizer and LM-head model, builds
    a masked-LM loader over ``texts`` (defined outside this function), and
    runs ``HuggingFaceRunner.train`` in ``check`` mode with an optimizer
    callback and ``PerplexityMetricCallback``.
    """
    pretrained = "distilbert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(pretrained)
    lm_model = AutoModelWithLMHead.from_pretrained(pretrained)
    dataset = LanguageModelingDataset(texts, tokenizer)
    batch_collator = DataCollatorForLanguageModeling(tokenizer).collate_batch
    train_loader = torch.utils.data.DataLoader(dataset, collate_fn=batch_collator)
    optimizer = torch.optim.Adam(lm_model.parameters(), lr=5e-5)

    runner = HuggingFaceRunner()
    callbacks = {
        "optimizer": dl.OptimizerCallback(),
        "perplexity": PerplexityMetricCallback(),
    }
    runner.train(
        model=lm_model,
        optimizer=optimizer,
        loaders={"train": train_loader},
        callbacks=callbacks,
        check=True,
    )
Exemple #16
0
 def get_callbacks(self, stage: str):
     """Return criterion, optimizer and profiler callbacks; ``stage`` is unused.

     The profiler is attached to the ``train`` loader at ``epoch=1`` and is
     asked to record CPU and CUDA activity with stack and FLOP information.
     Its output locations come from ``self.profiler_tb_logs``,
     ``self.chrome_trace_logs`` and ``self._export_stacks_kwargs``.
     """
     criterion_cb = dl.CriterionCallback(
         input_key="logits", target_key="targets", metric_key="loss"
     )
     optimizer_cb = dl.OptimizerCallback(metric_key="loss")
     activities = [
         torch.profiler.ProfilerActivity.CPU,
         torch.profiler.ProfilerActivity.CUDA,
     ]
     profiler_cb = ProfilerCallback(
         loader_key="train",
         epoch=1,
         profiler_kwargs={
             "activities": activities,
             "with_stack": True,
             "with_flops": True,
         },
         tensorboard_path=self.profiler_tb_logs,
         export_chrome_trace_path=self.chrome_trace_logs,
         export_stacks_kwargs=self._export_stacks_kwargs,
     )
     return {"criterion": criterion_cb, "optimizer": optimizer_cb, "profiler": profiler_cb}
        in_size=DATASETS[args.dataset]["in_size"],
        in_channels=DATASETS[args.dataset]["in_channels"],
        feature_dim=args.feature_dim,
    )
    optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-6)

    # define criterion
    criterion = BarlowTwinsLoss(offdiag_lambda=args.offdig_lambda)

    # and callbacks
    callbacks = [
        dl.CriterionCallback(
            input_key="projection_left", target_key="projection_right", metric_key="loss"
        ),
        dl.BackwardCallback(metric_key="loss"),
        dl.OptimizerCallback(metric_key="loss"),
        dl.SklearnModelCallback(
            feature_key="embedding_origin",
            target_key="target",
            train_loader="train",
            valid_loaders="valid",
            model_fn=LogisticRegression,
            predict_key="sklearn_predict",
            predict_method="predict_proba",
            C=0.1,
            solver="saga",
            max_iter=200,
        ),
        dl.ControlFlowCallbackWrapper(
            dl.AccuracyCallback(
                target_key="target", input_key="sklearn_predict", topk=(1, 3)
Exemple #18
0
def train_experiment(device):
    """Train a small GAN on MNIST for one epoch, then draw one sample.

    The generator maps a ``latent_dim``-dimensional Gaussian vector to a
    1x28x28 image in (0, 1). The discriminator maps an image to one logit.
    Both use ``BCEWithLogitsLoss`` and Adam(lr=3e-4, betas=(0.5, 0.999)).
    The MNIST *test* split (``train=False``) is the only loader; it is named
    ``"train"`` and is also passed as ``valid_loader``.

    Args:
        device: device passed to ``dl.DeviceEngine``.
    """
    with TemporaryDirectory() as logdir:
        latent_dim = 128
        generator = nn.Sequential(
            # Project the latent vector to 128 * 7 * 7 coefficients and
            # reshape them into a (128, 7, 7) feature map.
            nn.Linear(latent_dim, 128 * 7 * 7),
            nn.LeakyReLU(0.2, inplace=True),
            Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
            # Two stride-2 transposed convolutions upsample 7 -> 14 -> 28.
            nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, (7, 7), padding=3),
            nn.Sigmoid(),
        )
        discriminator = nn.Sequential(
            # Two stride-2 convolutions downsample 28 -> 14 -> 7.
            nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # Global max pool + flatten -> 128 features -> one logit.
            GlobalMaxPool2d(),
            Flatten(),
            nn.Linear(128, 1),
        )

        model = {"generator": generator, "discriminator": discriminator}
        criterion = {
            "generator": nn.BCEWithLogitsLoss(),
            "discriminator": nn.BCEWithLogitsLoss()
        }
        optimizer = {
            "generator":
            torch.optim.Adam(generator.parameters(),
                             lr=0.0003,
                             betas=(0.5, 0.999)),
            "discriminator":
            torch.optim.Adam(discriminator.parameters(),
                             lr=0.0003,
                             betas=(0.5, 0.999)),
        }
        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        class CustomRunner(dl.Runner):
            def predict_batch(self, batch):
                """Ignore ``batch`` and return one detached generated image."""
                batch_size = 1
                # Sample random points in the latent space
                random_latent_vectors = torch.randn(batch_size,
                                                    latent_dim).to(self.device)
                # Decode them to fake images
                generated_images = self.model["generator"](
                    random_latent_vectors).detach()
                return generated_images

            def handle_batch(self, batch):
                """Compute discriminator and generator predictions for a batch.

                Label convention: generated images are labelled 1 and real
                images 0, so the generator's "misleading" target is all zeros.
                """
                real_images, _ = batch
                batch_size = real_images.shape[0]

                # Sample random points in the latent space
                random_latent_vectors = torch.randn(batch_size,
                                                    latent_dim).to(self.device)

                # Decode them to fake images; detached so the discriminator
                # loss does not propagate into the generator.
                generated_images = self.model["generator"](
                    random_latent_vectors).detach()
                # Combine them with real images
                combined_images = torch.cat([generated_images, real_images])

                # Assemble labels discriminating real from fake images
                labels = torch.cat([
                    torch.ones((batch_size, 1)),
                    torch.zeros((batch_size, 1))
                ]).to(self.device)
                # Add random noise to the labels - important trick!
                labels += 0.05 * torch.rand(labels.shape).to(self.device)

                # Discriminator forward
                combined_predictions = self.model["discriminator"](
                    combined_images)

                # Sample random points in the latent space
                random_latent_vectors = torch.randn(batch_size,
                                                    latent_dim).to(self.device)
                # Assemble labels that say "all real images"
                misleading_labels = torch.zeros(
                    (batch_size, 1)).to(self.device)

                # Generator forward. Freeze the discriminator's parameters so
                # that loss_generator yields gradients for the generator only.
                # Otherwise its backward leaves gradients on the discriminator
                # that the generator-scoped OptimizerCallback
                # (model_key="generator") presumably does not clear. Those
                # gradients would then be mixed into the discriminator's own
                # update. Verify this against the installed Catalyst version.
                disc_params = list(self.model["discriminator"].parameters())
                saved_flags = [p.requires_grad for p in disc_params]
                for p in disc_params:
                    p.requires_grad_(False)
                try:
                    generated_images = self.model["generator"](
                        random_latent_vectors)
                    generated_predictions = self.model["discriminator"](
                        generated_images)
                finally:
                    for p, flag in zip(disc_params, saved_flags):
                        p.requires_grad_(flag)

                self.batch = {
                    "combined_predictions": combined_predictions,
                    "labels": labels,
                    "generated_predictions": generated_predictions,
                    "misleading_labels": misleading_labels,
                }

        runner = CustomRunner()
        runner.train(
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.CriterionCallback(
                    input_key="combined_predictions",
                    target_key="labels",
                    metric_key="loss_discriminator",
                    criterion_key="discriminator",
                ),
                dl.CriterionCallback(
                    input_key="generated_predictions",
                    target_key="misleading_labels",
                    metric_key="loss_generator",
                    criterion_key="generator",
                ),
                dl.OptimizerCallback(
                    model_key="generator",
                    optimizer_key="generator",
                    metric_key="loss_generator",
                ),
                dl.OptimizerCallback(
                    model_key="discriminator",
                    optimizer_key="discriminator",
                    metric_key="loss_discriminator",
                ),
            ],
            valid_loader="train",
            valid_metric="loss_generator",
            minimize_valid_metric=True,
            num_epochs=1,
            verbose=False,
            logdir=logdir,
        )
        # Smoke-check inference: draw one image and move it to numpy.
        runner.predict_batch(None)[0, 0].cpu().numpy()
Exemple #19
0
def train_experiment(engine=None):
    """Run a one-epoch multi-label recommender smoke test with SupervisedRunner.

    Random features and 0/1 targets feed a linear model trained with BCE on
    ``logits``. A ``BatchTransformCallback`` derives ``scores = sigmoid(logits)``
    at ``on_batch_end``, and the ranking metrics read those scores. The
    checkpoint tracks ``map01`` (higher is better) in a temporary directory.
    AUC on raw ``logits`` is added only when ``engine`` is a ``dl.CPUEngine``.

    Args:
        engine: Catalyst engine forwarded to ``runner.train`` (may be ``None``).
    """
    with TemporaryDirectory() as logdir:
        n_users, n_features, n_items = int(1e4), int(1e1), 10
        features = torch.rand(n_users, n_features)
        targets = (torch.rand(n_users, n_items) > 0.5).to(torch.float32)

        loader = DataLoader(TensorDataset(features, targets), batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        model = torch.nn.Linear(n_features, n_items)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        score_keys = dict(input_key="scores", target_key="targets")
        callbacks = [
            dl.BatchTransformCallback(
                input_key="logits",
                output_key="scores",
                transform=torch.sigmoid,
                scope="on_batch_end",
            ),
            dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            dl.HitrateCallback(topk=(1, 3, 5), **score_keys),
            dl.MRRCallback(topk=(1, 3, 5), **score_keys),
            dl.MAPCallback(topk=(1, 3, 5), **score_keys),
            dl.NDCGCallback(topk=(1, 3), **score_keys),
            dl.BackwardCallback(metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            dl.CheckpointCallback(
                logdir=logdir, loader_key="valid", metric_key="map01", minimize=False
            ),
        ]
        if isinstance(engine, dl.CPUEngine):
            callbacks.append(dl.AUCCallback(input_key="logits", target_key="targets"))

        runner = dl.SupervisedRunner(
            input_key="features",
            output_key="logits",
            target_key="targets",
            loss_key="loss",
        )
        runner.train(
            engine=engine,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=callbacks,
        )
Exemple #20
0
    item_num = len(train_dataset[0])
    model = MultiVAE([200, 600, item_num], dropout=0.5)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = StepLR(optimizer, step_size=20, gamma=0.1)
    engine = dl.Engine()
    hparams = {
        "anneal_cap": 0.2,
        "total_anneal_steps": 6000,
    }
    callbacks = [
        dl.NDCGCallback("logits", "targets", [20, 50, 100]),
        dl.MAPCallback("logits", "targets", [20, 50, 100]),
        dl.MRRCallback("logits", "targets", [20, 50, 100]),
        dl.HitrateCallback("logits", "targets", [20, 50, 100]),
        dl.BackwardCallback("loss"),
        dl.OptimizerCallback("loss", accumulation_steps=1),
        dl.SchedulerCallback(),
    ]

    runner = RecSysRunner()
    runner.train(
        model=model,
        optimizer=optimizer,
        engine=engine,
        hparams=hparams,
        scheduler=lr_scheduler,
        loaders=loaders,
        num_epochs=100,
        verbose=True,
        timeit=False,
        callbacks=callbacks,
Exemple #21
0
def main(args):
    """Train an MNIST MLP and prune it, either after training or iteratively.

    Two modes, selected by ``args.vanilla_pruning``:

    * vanilla: train with ``dl.SupervisedRunner`` (best checkpoint by
      ``accuracy01`` reloaded at the end), then evaluate the trained model
      under pruning via ``validate_model`` and save the returned accuracy /
      amount values to ``accuracy.pth`` and ``amount.pth``.
    * iterative: train with ``PruneRunner`` for ``args.num_sessions`` sessions
      with a ``dl.PruningCallback``; optionally add a ``LotteryTicketCallback``
      (``args.lottery_ticket``) and/or start from a saved state dict and add a
      KL-divergence distillation loss (``args.kd``).

    Args:
        args: parsed CLI arguments. Attributes read here: ``seed``, ``device``,
            ``vanilla_pruning``, ``num_epochs``, ``pruning_method``,
            ``amount``, ``dim``, ``n``, ``num_sessions``, ``lottery_ticket``,
            ``kd``, ``state_dict``, ``probability_shift``, ``logdir``.
    """
    import copy

    train_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=True, transform=ToTensor())
    )
    val_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=False, transform=ToTensor())
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64)
    loaders = {"train": train_dataloader, "valid": val_dataloader}
    utils.set_global_seed(args.seed)
    net = nn.Sequential(
        Flatten(),
        nn.Linear(28 * 28, 300),
        nn.ReLU(),
        nn.Linear(300, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    # ``state_dict()`` returns tensors that share storage with the live
    # parameters, so in-place optimizer updates (and ``load_state_dict``
    # below) would silently overwrite a plain reference. Deep-copy to keep a
    # real snapshot of the initial weights for lottery-ticket rewinding.
    initial_state_dict = copy.deepcopy(net.state_dict())

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    engine = dl.DeviceEngine(args.device) if args.device is not None else None

    if args.vanilla_pruning:
        runner = dl.SupervisedRunner(engine=engine)

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            ],
            logdir="./logdir",
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
        pruning_fn = partial(
            utils.pruning.prune_model,
            pruning_fn=args.pruning_method,
            amount=args.amount,
            # nn.Linear names its parameter ``weight`` (singular); this matches
            # the key used by ``dl.PruningCallback`` in the other branch.
            keys_to_prune=["weight"],
            dim=args.dim,
            l_norm=args.n,
        )
        acc, amount = validate_model(
            runner, pruning_fn=pruning_fn, loader=loaders["valid"], num_sessions=args.num_sessions
        )
        torch.save(acc, "accuracy.pth")
        torch.save(amount, "amount.pth")

    else:
        runner = PruneRunner(num_sessions=args.num_sessions, engine=engine)
        callbacks = [
            dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            dl.PruningCallback(
                args.pruning_method,
                keys_to_prune=["weight"],
                amount=args.amount,
                remove_reparametrization_on_stage_end=False,
            ),
            dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
        ]
        if args.lottery_ticket:
            callbacks.append(LotteryTicketCallback(initial_state_dict=initial_state_dict))
        if args.kd:
            # NOTE(review): initial_state_dict was snapshotted *before* this
            # checkpoint is loaded, so lottery-ticket rewinding (if enabled
            # together with kd) targets the random initialisation, not the
            # loaded weights — confirm this is intended.
            net.load_state_dict(torch.load(args.state_dict))
            callbacks.append(
                PrepareForFinePruningCallback(probability_shift=args.probability_shift)
            )
            callbacks.append(KLDivCallback(temperature=4, student_logits_key="logits"))
            # Final loss = 0.1 * cross-entropy + 0.9 * KL-divergence.
            callbacks.append(
                MetricAggregationCallback(
                    prefix="loss", metrics={"loss": 0.1, "kl_div_loss": 0.9}, mode="weighted_sum"
                )
            )

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=callbacks,
            logdir=args.logdir,
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
Exemple #22
0
def train_experiment(device):
    """Smoke-train a two-headed classifier on random data for one epoch.

    A shared linear trunk feeds two heads (4 and 10 classes). Each head gets
    its own cross-entropy loss, and the two are averaged into ``"loss"``.
    Per-head accuracy and confusion matrices are tracked, and checkpoints
    and TensorBoard logs are written inside a temporary directory that is
    removed afterwards.

    Args:
        device: device passed to ``dl.DeviceEngine``.
    """
    import os

    with TemporaryDirectory() as logdir:
        # sample data
        num_samples, num_features = int(1e4), int(1e1)
        num_classes1, num_classes2 = 4, 10
        X = torch.rand(num_samples, num_features)
        y1 = (torch.rand(num_samples) * num_classes1).to(torch.int64)
        y2 = (torch.rand(num_samples) * num_classes2).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y1, y2)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        class CustomModule(nn.Module):
            """Shared linear trunk with two independent linear heads."""

            def __init__(self, in_features: int, out_features1: int,
                         out_features2: int):
                super().__init__()
                self.shared = nn.Linear(in_features, 128)
                self.head1 = nn.Linear(128, out_features1)
                self.head2 = nn.Linear(128, out_features2)

            def forward(self, x):
                x = self.shared(x)
                return self.head1(x), self.head2(x)

        # model, criterion, optimizer, scheduler
        model = CustomModule(num_features, num_classes1, num_classes2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [2])

        class CustomRunner(dl.Runner):
            def handle_batch(self, batch):
                x, y1, y2 = batch
                y1_hat, y2_hat = self.model(x)
                self.batch = {
                    "features": x,
                    "logits1": y1_hat,
                    "logits2": y2_hat,
                    "targets1": y1,
                    "targets2": y2,
                }

        # model training
        runner = CustomRunner()
        runner.train(
            # Honour the ``device`` argument instead of ignoring it.
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=[
                dl.CriterionCallback(metric_key="loss1",
                                     input_key="logits1",
                                     target_key="targets1"),
                dl.CriterionCallback(metric_key="loss2",
                                     input_key="logits2",
                                     target_key="targets2"),
                dl.MetricAggregationCallback(prefix="loss",
                                             metrics=["loss1", "loss2"],
                                             mode="mean"),
                dl.OptimizerCallback(metric_key="loss"),
                dl.SchedulerCallback(),
                dl.AccuracyCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_",
                ),
                dl.AccuracyCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_",
                ),
                # catalyst[ml] required
                dl.ConfusionMatrixCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_cm",
                ),
                # catalyst[ml] required
                dl.ConfusionMatrixCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_cm",
                ),
                # Artifacts go inside the temporary logdir rather than
                # leaking into ./logs in the working directory.
                dl.CheckpointCallback(
                    os.path.join(logdir, "one"),
                    loader_key="valid",
                    # Top-k accuracy keys carry a two-digit k suffix,
                    # cf. ``two_accuracy03`` below.
                    metric_key="one_accuracy01",
                    minimize=False,
                    save_n_best=1,
                ),
                dl.CheckpointCallback(
                    os.path.join(logdir, "two"),
                    loader_key="valid",
                    metric_key="two_accuracy03",
                    minimize=False,
                    save_n_best=3,
                ),
            ],
            loggers={
                "console": dl.ConsoleLogger(),
                "tb": dl.TensorboardLogger(os.path.join(logdir, "tb")),
            },
        )
Exemple #23
0
        total_preds = [vectorizer.devectorize(i) for i in seq]
        total_tags = [vectorizer.devectorize(i) for i in tags]

        self.input = {
            'x': sents,
            'x_char': chars,
            'y': tags,
            'total_tags': total_tags
        }  # 'mask': mask,
        self.output = {'preds': total_preds}


# Training callbacks keyed by role. The criterion reads the 'x', 'x_char'
# and 'y' inputs (a 'mask' input is currently disabled). The metric compares
# the 'preds' output against the 'total_tags' input using ``ner_token_f1``.
callbacks = dict(
    optimizer=dl.OptimizerCallback(
        metric_key="loss",
        accumulation_steps=1,
        grad_clip_params=None,
    ),
    criterion=dl.CriterionCallback(input_key=['x', 'x_char', 'y'], output_key=[]),
    metric=dl.MetricCallback(
        input_key='total_tags',
        output_key='preds',
        prefix='F1_token',
        metric_fn=ner_token_f1,
    ),
    checkpoints=CheckpointCallback(save_n_best=3),
)
"""
callbacks = [
def train_experiment(engine=None):
    """Smoke-train a two-headed ``CustomModule`` on random data for one epoch.

    Each head gets its own cross-entropy loss, and the two are averaged into
    ``"loss"``. Per-head accuracy is tracked, and confusion matrices are
    added when ``SETTINGS.ml_required`` is set. Checkpoints and TensorBoard
    logs are written inside a temporary directory that is removed afterwards.

    Args:
        engine: optional Catalyst engine, passed straight through to
            ``runner.train``.
    """
    import os

    with TemporaryDirectory() as logdir:
        # sample data
        num_samples, num_features = int(1e4), int(1e1)
        num_classes1, num_classes2 = 4, 10
        X = torch.rand(num_samples, num_features)
        y1 = (torch.rand(num_samples) * num_classes1).to(torch.int64)
        y2 = (torch.rand(num_samples) * num_classes2).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y1, y2)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = CustomModule(num_features, num_classes1, num_classes2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [2])

        callbacks = [
            dl.CriterionCallback(metric_key="loss1",
                                 input_key="logits1",
                                 target_key="targets1"),
            dl.CriterionCallback(metric_key="loss2",
                                 input_key="logits2",
                                 target_key="targets2"),
            dl.MetricAggregationCallback(metric_key="loss",
                                         metrics=["loss1", "loss2"],
                                         mode="mean"),
            dl.BackwardCallback(metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            dl.AccuracyCallback(
                input_key="logits1",
                target_key="targets1",
                num_classes=num_classes1,
                prefix="one_",
            ),
            dl.AccuracyCallback(
                input_key="logits2",
                target_key="targets2",
                num_classes=num_classes2,
                prefix="two_",
            ),
            # Checkpoints go inside the temporary logdir rather than leaking
            # into ./logs in the working directory.
            dl.CheckpointCallback(
                os.path.join(logdir, "one"),
                loader_key="valid",
                metric_key="one_accuracy01",
                minimize=False,
                topk=1,
            ),
            dl.CheckpointCallback(
                os.path.join(logdir, "two"),
                loader_key="valid",
                metric_key="two_accuracy03",
                minimize=False,
                topk=3,
            ),
        ]
        if SETTINGS.ml_required:
            # catalyst[ml] required
            callbacks.append(
                dl.ConfusionMatrixCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_cm",
                ))
            # catalyst[ml] required
            callbacks.append(
                dl.ConfusionMatrixCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_cm",
                ))

        # model training
        runner = CustomRunner()
        runner.train(
            engine=engine,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=callbacks,
            loggers={
                "console": dl.ConsoleLogger(),
                "tb": dl.TensorboardLogger(os.path.join(logdir, "tb")),
            },
        )
Exemple #25
0
                                            latent_dim).to(self.device)
        # Assemble labels that say "all real images"
        misleading_labels = torch.zeros((batch_size, 1)).to(self.device)

        # Train the generator
        generated_images = self.model["generator"](random_latent_vectors)
        predictions = self.model["discriminator"](generated_images)
        batch_metrics["loss_generator"] = F.binary_cross_entropy_with_logits(
            predictions, misleading_labels)

        self.state.batch_metrics.update(**batch_metrics)


# Fit the GAN. Each sub-network's optimizer is stepped on its own loss, taken
# from the runner's batch metrics. "loss_generator" doubles as the main metric.
runner = CustomRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    callbacks=[
        dl.OptimizerCallback(optimizer_key=net_name, metric_key=loss_name)
        for net_name, loss_name in (
            ("generator", "loss_generator"),
            ("discriminator", "loss_discriminator"),
        )
    ],
    main_metric="loss_generator",
    num_epochs=20,
    logdir="./logs_gan",
    verbose=True,
    check=True,
)
Exemple #26
0
def train_experiment(device):
    """Distil a linear MNIST "teacher" into a linear "student" for one epoch.

    Only the student is optimised: the Adam optimizer covers
    ``student.parameters()``, and the teacher runs under ``torch.no_grad()``.
    The training loss is the mean of a cross-entropy term against the true
    labels and a KL-divergence term against the teacher's softmax. Both
    networks' accuracies are logged, and the three best checkpoints by
    validation loss are kept in a temporary directory.

    Args:
        device: device passed to ``dl.DeviceEngine``.
    """
    # MNIST has 10 digit classes; both networks have 10-way output layers.
    num_classes = 10

    with TemporaryDirectory() as logdir:
        teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, num_classes))
        student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, num_classes))
        criterion = {
            "cls": nn.CrossEntropyLoss(),
            "kl": nn.KLDivLoss(reduction="batchmean"),
        }
        optimizer = optim.Adam(student.parameters(), lr=0.02)

        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=True,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
            "valid":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        class DistilRunner(dl.Runner):
            def handle_batch(self, batch):
                x, y = batch

                # Switch the same module instance we call below to eval mode,
                # via the runner's model dict rather than the closed-over
                # ``teacher`` name.
                self.model["teacher"].eval()
                with torch.no_grad():
                    t_logits = self.model["teacher"](x)

                s_logits = self.model["student"](x)
                self.batch = {
                    "t_logits": t_logits,
                    "s_logits": s_logits,
                    "targets": y,
                    # KLDivLoss expects log-probabilities as input and
                    # probabilities as target.
                    "s_logprobs": F.log_softmax(s_logits, dim=-1),
                    "t_probs": F.softmax(t_logits, dim=-1),
                }

        runner = DistilRunner()
        # model training
        runner.train(
            engine=dl.DeviceEngine(device),
            model={
                "teacher": teacher,
                "student": student
            },
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=True,
            callbacks=[
                dl.AccuracyCallback(input_key="t_logits",
                                    target_key="targets",
                                    num_classes=num_classes,
                                    prefix="teacher_"),
                dl.AccuracyCallback(input_key="s_logits",
                                    target_key="targets",
                                    num_classes=num_classes,
                                    prefix="student_"),
                dl.CriterionCallback(
                    input_key="s_logits",
                    target_key="targets",
                    metric_key="cls_loss",
                    criterion_key="cls",
                ),
                dl.CriterionCallback(
                    input_key="s_logprobs",
                    target_key="t_probs",
                    metric_key="kl_div_loss",
                    criterion_key="kl",
                ),
                dl.MetricAggregationCallback(
                    prefix="loss",
                    metrics=["kl_div_loss", "cls_loss"],
                    mode="mean"),
                dl.OptimizerCallback(metric_key="loss", model_key="student"),
                dl.CheckpointCallback(
                    logdir=logdir,
                    loader_key="valid",
                    metric_key="loss",
                    minimize=True,
                    save_n_best=3,
                ),
            ],
        )
Exemple #27
0
def test_runner():
    """Smoke-test ``DistilMLMRunner`` on a BERT -> DistilBERT distillation.

    The teacher is ``bert-base-uncased`` and the student is
    ``distilbert-base-uncased``. Masked-LM loaders are built from the
    ``text`` column of ``data/train.csv`` and ``data/valid.csv``. Training
    runs with ``check=True``, and the test passes if ``runner.train``
    returns without raising.
    """
    train_df = pd.read_csv("data/train.csv")
    valid_df = pd.read_csv("data/valid.csv")

    def load_mlm(model_cls, checkpoint):
        # Teacher and student share the same extra config flags.
        config = AutoConfig.from_pretrained(
            checkpoint,
            output_hidden_states=True,
            output_logits=True,
        )
        return model_cls.from_pretrained(checkpoint, config=config)

    teacher = load_mlm(BertForMaskedLM, "bert-base-uncased")
    student = load_mlm(DistilBertForMaskedLM, "distilbert-base-uncased")

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    collate_fn = DataCollatorForLanguageModeling(tokenizer)
    loaders = {
        split: DataLoader(
            LanguageModelingDataset(frame["text"], tokenizer),
            collate_fn=collate_fn,
            batch_size=2,
        )
        for split, frame in (("train", train_df), ("valid", valid_df))
    }

    # The total "loss" is the unweighted sum of the four distillation losses.
    callbacks = dict(
        masked_lm_loss=MaskedLanguageModelCallback(),
        mse_loss=MSELossCallback(),
        cosine_loss=CosineLossCallback(),
        kl_div_loss=KLDivLossCallback(),
        loss=MetricAggregationCallback(
            prefix="loss",
            mode="weighted_sum",
            metrics={
                "cosine_loss": 1.0,
                "masked_lm_loss": 1.0,
                "kl_div_loss": 1.0,
                "mse_loss": 1.0,
            },
        ),
        optimizer=dl.OptimizerCallback(),
        perplexity=PerplexityMetricCallbackDistillation(),
    )

    model = torch.nn.ModuleDict({"teacher": teacher, "student": student})
    runner = DistilMLMRunner()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
    runner.train(
        model=model,
        optimizer=optimizer,
        loaders=loaders,
        callbacks=callbacks,
        verbose=True,
        check=True,
    )
    # Reaching this point without an exception is the success criterion.
Exemple #28
0
def main():
    """Train a minimal fully-connected GAN on MNIST with Catalyst.

    The generator maps 128 inputs to 28*28 outputs squashed by ``Tanh``. The
    discriminator maps 28*28 inputs to a single ``Sigmoid`` output. Each
    network has its own Adam optimizer, stepped by an ``OptimizerCallback``
    on its own loss (``loss_generator`` / ``loss_discriminator``), which
    ``CustomRunner`` is expected to report. Training runs for 5 epochs and
    logs to ``./logs/gan``.
    """
    # NOTE(review): ``ToTensor`` yields real images in [0, 1] while ``Tanh``
    # emits values in [-1, 1]; unless CustomRunner rescales one side, the
    # discriminator can separate real from fake by sign alone — confirm.
    generator = nn.Sequential(nn.Linear(128, 28 * 28), nn.Tanh())
    discriminator = nn.Sequential(nn.Linear(28 * 28, 1), nn.Sigmoid())
    model = nn.ModuleDict({
        "generator": generator,
        "discriminator": discriminator
    })

    generator_optimizer = torch.optim.Adam(generator.parameters(),
                                           lr=0.0001,
                                           betas=(0.5, 0.999))
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),
                                               lr=0.0001,
                                               betas=(0.5, 0.999))
    optimizer = {
        "generator": generator_optimizer,
        "discriminator": discriminator_optimizer,
    }

    # Train on the MNIST training split and validate on the held-out test split.
    loaders = {
        "train":
        DataLoader(
            MNIST(
                os.getcwd(),
                train=True,
                download=True,
                transform=transforms.ToTensor(),
            ),
            batch_size=32,
        ),
        "valid":
        DataLoader(
            MNIST(
                os.getcwd(),
                train=False,
                download=True,
                transform=transforms.ToTensor(),
            ),
            batch_size=32,
        ),
    }

    runner = CustomRunner()
    runner.train(
        model=model,
        optimizer=optimizer,
        loaders=loaders,
        callbacks=[
            dl.OptimizerCallback(optimizer_key="generator",
                                 loss_key="loss_generator"),
            dl.OptimizerCallback(optimizer_key="discriminator",
                                 loss_key="loss_discriminator"),
        ],
        main_metric="loss_generator",
        num_epochs=5,
        logdir="./logs/gan",
        verbose=True,
        check=True,
    )
Exemple #29
0
def train_experiment(device, engine=None):
    """Run a one-epoch GAN smoke test on MNIST.

    The generator maps a ``latent_dim``-sized vector to 28*28 values through
    ``_ddp_hack`` and a ``Sigmoid``. The discriminator is a single linear
    layer over flattened images. Each network gets its own
    ``BCEWithLogitsLoss`` and Adam optimizer. The batch keys consumed below
    (``combined_predictions``, ``labels``, ``generated_predictions``,
    ``misleading_labels``) are presumably produced by
    ``CustomRunner.handle_batch``, which is not visible here.

    Args:
        device: device for the default ``dl.DeviceEngine``.
        engine: optional engine that, when truthy, replaces the default one.
    """
    with TemporaryDirectory() as logdir:
        latent_dim = 32
        generator = nn.Sequential(
            nn.Linear(latent_dim, 28 * 28),
            Lambda(_ddp_hack),
            nn.Sigmoid(),
        )
        discriminator = nn.Sequential(Flatten(), nn.Linear(28 * 28, 1))
        model = {"generator": generator, "discriminator": discriminator}

        # One loss and one Adam optimizer per sub-network, same hyper-params.
        criterion = {name: nn.BCEWithLogitsLoss() for name in model}
        optimizer = {
            name: torch.optim.Adam(net.parameters(), lr=0.0003, betas=(0.5, 0.999))
            for name, net in model.items()
        }

        # NOTE: only the MNIST test split (train=False) is loaded, and it is
        # also used as the validation loader below.
        mnist = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())
        loaders = {"train": DataLoader(mnist, batch_size=32)}

        callbacks = [
            dl.CriterionCallback(
                input_key="combined_predictions",
                target_key="labels",
                metric_key="loss_discriminator",
                criterion_key="discriminator",
            ),
            dl.CriterionCallback(
                input_key="generated_predictions",
                target_key="misleading_labels",
                metric_key="loss_generator",
                criterion_key="generator",
            ),
            dl.OptimizerCallback(
                model_key="generator",
                optimizer_key="generator",
                metric_key="loss_generator",
            ),
            dl.OptimizerCallback(
                model_key="discriminator",
                optimizer_key="discriminator",
                metric_key="loss_discriminator",
            ),
        ]

        runner = CustomRunner(latent_dim)
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=callbacks,
            valid_loader="train",
            valid_metric="loss_generator",
            minimize_valid_metric=True,
            num_epochs=1,
            verbose=False,
            logdir=logdir,
        )

        # Prediction smoke check, skipped for the DDP engine.
        if not isinstance(engine, dl.DistributedDataParallelEngine):
            runner.predict_batch(None)[0, 0].cpu().numpy()
Exemple #30
0
# Distillation callbacks. The total "loss" is the unweighted sum of the four
# individual distillation losses.
callbacks = {
    "masked_lm_loss": MaskedLanguageModelCallback(),
    "mse_loss": MSELossCallback(),
    "cosine_loss": CosineLossCallback(),
    "kl_div_loss": KLDivLossCallback(),
    "loss": MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum",
        metrics={
            "cosine_loss": 1.0,
            "masked_lm_loss": 1.0,
            "kl_div_loss": 1.0,
            "mse_loss": 1.0,
        },
    ),
    "optimizer": dl.OptimizerCallback(),
    "perplexity": PerplexityMetricCallbackDistillation(),
}


# Use the GPU when one is available, otherwise fall back to CPU instead of
# hard-requiring CUDA.
runner = DistilMLMRunner(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    verbose=True,
    check=True,
    callbacks=callbacks,
)