Example #1
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        # sample data
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples, num_classes) > 0.5).to(torch.float32)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # model training
        runner = dl.SupervisedRunner(input_key="features",
                                     output_key="logits",
                                     target_key="targets",
                                     loss_key="loss")
        callbacks = [
            dl.BatchTransformCallback(
                transform="F.sigmoid",
                scope="on_batch_end",
                input_key="logits",
                output_key="scores",
            ),
            dl.MultilabelAccuracyCallback(input_key="scores",
                                          target_key="targets",
                                          threshold=0.5),
            dl.MultilabelPrecisionRecallF1SupportCallback(
                input_key="scores",
                target_key="targets",
                num_classes=num_classes),
        ]
        if SETTINGS.amp_required and (engine is None or not isinstance(
                engine,
            (dl.AMPEngine, dl.DataParallelAMPEngine,
             dl.DistributedDataParallelAMPEngine),
        )):
            callbacks.append(
                dl.AUCCallback(input_key="scores", target_key="targets"))
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            valid_loader="valid",
            valid_metric="accuracy",
            minimize_valid_metric=False,
            verbose=False,
            callbacks=callbacks,
        )
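
A minimal sketch of how such an experiment function might be launched; the device string below is an assumption, not part of the original example:

# hypothetical entry point, assuming the imports used throughout these examples
if __name__ == "__main__":
    train_experiment("cuda" if torch.cuda.is_available() else "cpu")
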
Example #2
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        model = {"teacher": teacher, "student": student}
        criterion = {"cls": nn.CrossEntropyLoss(), "kl": nn.KLDivLoss(reduction="batchmean")}
        optimizer = optim.Adam(student.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32
            ),
        }

        runner = DistilRunner()
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=False,
            callbacks=[
                dl.AccuracyCallback(
                    input_key="t_logits", target_key="targets", num_classes=2, prefix="teacher_"
                ),
                dl.AccuracyCallback(
                    input_key="s_logits", target_key="targets", num_classes=2, prefix="student_"
                ),
                dl.CriterionCallback(
                    input_key="s_logits",
                    target_key="targets",
                    metric_key="cls_loss",
                    criterion_key="cls",
                ),
                dl.CriterionCallback(
                    input_key="s_logprobs",
                    target_key="t_probs",
                    metric_key="kl_div_loss",
                    criterion_key="kl",
                ),
                dl.MetricAggregationCallback(
                    metric_key="loss", metrics=["kl_div_loss", "cls_loss"], mode="mean"
                ),
                dl.OptimizerCallback(metric_key="loss", model_key="student"),
                dl.CheckpointCallback(
                    logdir=logdir,
                    loader_key="valid",
                    metric_key="loss",
                    minimize=True,
                    save_n_best=3,
                ),
            ],
        )
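
This snippet assumes a DistilRunner that fills the t_logits, s_logits, s_logprobs, and t_probs batch keys; Example #23 below contains a matching definition. Condensed here for reference:

# condensed from the DistilRunner defined in Example #23
import torch
import torch.nn.functional as F
from catalyst import dl

class DistilRunner(dl.Runner):
    def handle_batch(self, batch):
        x, y = batch
        self.model["teacher"].eval()  # teacher stays frozen during distillation
        with torch.no_grad():
            t_logits = self.model["teacher"](x)
        s_logits = self.model["student"](x)
        self.batch = {
            "t_logits": t_logits,
            "s_logits": s_logits,
            "targets": y,
            "s_logprobs": F.log_softmax(s_logits, dim=-1),
            "t_probs": F.softmax(t_logits, dim=-1),
        }
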
Example #3
def train_experiment(device):
    with TemporaryDirectory() as logdir:
        # data
        num_samples, num_features = int(1e4), int(1e1)
        X, y = torch.rand(num_samples, num_features), torch.rand(num_samples)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, 1)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6])

        # model training
        runner = dl.SupervisedRunner()
        runner.train(
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir=logdir,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            num_epochs=1,
            verbose=False,
        )
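
After training, the fitted runner can serve predictions; this mirrors the inference loop of Example #15 below and assumes the default "logits" output key of SupervisedRunner:

# hedged sketch: would run inside train_experiment, after runner.train(...)
for prediction in runner.predict_loader(loader=loaders["valid"]):
    assert prediction["logits"].detach().cpu().numpy().shape[-1] == 1
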
Example #4
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:

        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=True,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
            "valid":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        runner = CustomRunner()
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            optimizer=optimizer,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            verbose=False,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
        )
Example #5
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        # sample data
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples) * num_classes).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # model training
        runner = dl.SupervisedRunner(input_key="features",
                                     output_key="logits",
                                     target_key="targets",
                                     loss_key="loss")
        callbacks = [
            dl.AccuracyCallback(input_key="logits",
                                target_key="targets",
                                num_classes=num_classes),
            dl.PrecisionRecallF1SupportCallback(input_key="logits",
                                                target_key="targets",
                                                num_classes=4),
            dl.ConfusionMatrixCallback(input_key="logits",
                                       target_key="targets",
                                       num_classes=4),
        ]
        if engine is None or not isinstance(
                engine, (dl.AMPEngine, dl.DataParallelAMPEngine,
                         dl.DistributedDataParallelAMPEngine)):
            callbacks.append(
                dl.AUCCallback(input_key="logits", target_key="targets"))

        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            valid_loader="valid",
            valid_metric="accuracy03",
            minimize_valid_metric=False,
            verbose=False,
            callbacks=callbacks,
        )
Example #6
def train_experiment(device):
    with TemporaryDirectory() as logdir:
        # sample data
        num_users, num_features, num_items = int(1e4), int(1e1), 10
        X = torch.rand(num_users, num_features)
        y = (torch.rand(num_users, num_items) > 0.5).to(torch.float32)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_items)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        class CustomRunner(dl.Runner):
            def handle_batch(self, batch):
                x, y = batch
                logits = self.model(x)
                self.batch = {
                    "features": x,
                    "logits": logits,
                    "scores": torch.sigmoid(logits),
                    "targets": y,
                }

        # model training
        runner = CustomRunner()
        runner.train(
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=[
                dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
                dl.AUCCallback(input_key="scores", target_key="targets"),
                dl.HitrateCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
                dl.MRRCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
                dl.MAPCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
                dl.NDCGCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
                dl.OptimizerCallback(metric_key="loss"),
                dl.SchedulerCallback(),
                dl.CheckpointCallback(
                    logdir=logdir, loader_key="valid", metric_key="map01", minimize=False
                ),
            ],
        )
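
The ranking callbacks above consume per-user item scores; actual top-k recommendations can be read off the same tensor with plain PyTorch. An illustration (the 32-user slice is an assumption):

# rank items for the first 32 users; torch.topk returns (values, indices)
with torch.no_grad():
    scores = torch.sigmoid(model(X[:32]))  # shape: (32, num_items)
topk_scores, topk_items = torch.topk(scores, k=5, dim=1)
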
Example #7
        def objective(trial):
            lr = trial.suggest_loguniform("lr", 1e-3, 1e-1)
            num_hidden = int(trial.suggest_loguniform("num_hidden", 32, 128))

            loaders = {
                "train":
                DataLoader(
                    MNIST(os.getcwd(),
                          train=False,
                          download=True,
                          transform=ToTensor()),
                    batch_size=32,
                ),
                "valid":
                DataLoader(
                    MNIST(os.getcwd(),
                          train=False,
                          download=True,
                          transform=ToTensor()),
                    batch_size=32,
                ),
            }
            model = nn.Sequential(nn.Flatten(), nn.Linear(784, num_hidden),
                                  nn.ReLU(), nn.Linear(num_hidden, 10))
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            criterion = nn.CrossEntropyLoss()

            runner = dl.SupervisedRunner(input_key="features",
                                         output_key="logits",
                                         target_key="targets")
            runner.train(
                engine=engine or dl.DeviceEngine(device),
                model=model,
                criterion=criterion,
                optimizer=optimizer,
                loaders=loaders,
                callbacks={
                    "optuna":
                    dl.OptunaPruningCallback(loader_key="valid",
                                             metric_key="accuracy01",
                                             minimize=False,
                                             trial=trial),
                    "accuracy":
                    dl.AccuracyCallback(input_key="logits",
                                        target_key="targets",
                                        num_classes=10),
                },
                num_epochs=2,
            )
            score = runner.callbacks["optuna"].best_score
            return score
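
The objective above is meant to be driven by an Optuna study; a typical driver, shown unindented for clarity (in the source, objective is nested inside a function that supplies device and engine, and the trial budget below is an assumption):

import optuna

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10, timeout=300)
print(study.best_value, study.best_params)
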
Example #8
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        # sample data
        num_users, num_features, num_items = int(1e4), int(1e1), 10
        X = torch.rand(num_users, num_features)
        y = (torch.rand(num_users, num_items) > 0.5).to(torch.float32)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_items)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        callbacks = [
            dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            dl.HitrateCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
            dl.MRRCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
            dl.MAPCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
            dl.NDCGCallback(input_key="scores", target_key="targets", topk_args=(1, 3, 5)),
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            dl.CheckpointCallback(
                logdir=logdir, loader_key="valid", metric_key="map01", minimize=False
            ),
        ]
        if engine is None or not isinstance(
            engine, (dl.AMPEngine, dl.DataParallelAMPEngine, dl.DistributedDataParallelAMPEngine)
        ):
            # AUC is computed on the sigmoid scores and skipped under AMP engines
            callbacks.append(dl.AUCCallback(input_key="scores", target_key="targets"))

        # model training
        runner = CustomRunner()
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=callbacks,
        )
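
The CustomRunner assumed here matches the inline runner of Example #6 above: it sigmoid-transforms the logits into the "scores" key that the ranking callbacks consume. Condensed for reference:

# condensed from the inline runner of Example #6
class CustomRunner(dl.Runner):
    def handle_batch(self, batch):
        x, y = batch
        logits = self.model(x)
        self.batch = {
            "features": x,
            "logits": logits,
            "scores": torch.sigmoid(logits),
            "targets": y,
        }
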
Example #9
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:

        # <--- multi-model setup --->
        encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128))
        head = nn.Linear(128, 10)
        model = {"encoder": encoder, "head": head}
        optimizer = optim.Adam([
            {
                "params": encoder.parameters()
            },
            {
                "params": head.parameters()
            },
        ],
                               lr=0.02)
        # <--- multi-model setup --->
        criterion = nn.CrossEntropyLoss()

        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=True,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
            "valid":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        runner = CustomRunner()
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            verbose=False,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
        )
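
The CustomRunner here is assumed to route each batch through both sub-models and to handle optimization itself, since no callbacks are passed; Example #20 below defines a compatible runner. Its core, condensed:

# condensed from the compatible runner of Example #20
class CustomRunner(dl.Runner):
    def handle_batch(self, batch):
        x, y = batch
        x_ = self.model["encoder"](x)    # shared feature extractor
        logits = self.model["head"](x_)  # classification head
        loss = self.criterion(logits, y)
        self.batch_metrics.update({"loss": loss})
        if self.is_train_loader:         # manual optimization step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
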
Example #10
def train_experiment(device):
    with TemporaryDirectory() as logdir:
        # sample data
        num_samples, num_features, num_classes = int(1e4), int(1e1), 4
        X = torch.rand(num_samples, num_features)
        y = (torch.rand(num_samples, num_classes) > 0.5).to(torch.float32)

        # pytorch loaders
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = torch.nn.Linear(num_features, num_classes)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # model training
        runner = dl.SupervisedRunner(input_key="features",
                                     output_key="logits",
                                     target_key="targets",
                                     loss_key="loss")
        runner.train(
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            valid_loader="valid",
            valid_metric="accuracy",
            minimize_valid_metric=False,
            verbose=False,
            callbacks=[
                dl.AUCCallback(input_key="logits", target_key="targets"),
                dl.MultilabelAccuracyCallback(input_key="logits",
                                              target_key="targets",
                                              threshold=0.5),
            ],
        )
Example #11
        n_step=1,
        gamma=gamma,
        history_len=1,
    )

    network, target_network = get_network(env), get_network(env)
    utils.set_requires_grad(target_network, requires_grad=False)
    models = {"origin": network, "target": target_network}
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    loaders = {"train_game": DataLoader(replay_buffer, batch_size=batch_size)}

    runner = CustomRunner(gamma=gamma, tau=tau, tau_period=tau_period)
    runner.train(
        # for simplicity reasons, let's run everything on single gpu
        engine=dl.DeviceEngine("cuda"),
        model=models,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs_dqn",
        num_epochs=50,
        verbose=True,
        valid_loader="_epoch_",
        valid_metric="reward",
        minimize_valid_metric=False,
        load_best_on_end=True,
        callbacks=[
            GameCallback(
                sampler_fn=Sampler,
                env=env,
Example #12
    network, target_network = get_network(env), get_network(env)
    utils.set_requires_grad(target_network, requires_grad=False)
    models = {"origin": network, "target": target_network}
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    loaders = {
        "train_game":
        DataLoader(
            ReplayDataset(replay_buffer, epoch_size=epoch_size),
            batch_size=batch_size,
        ),
    }

    runner = CustomRunner(gamma=gamma, tau=tau, tau_period=tau_period)
    runner.train(
        engine=dl.DeviceEngine(
            "cpu"),  # for simplicity reasons, let's run everything on cpu
        model=models,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs_dqn",
        num_epochs=10,
        verbose=True,
        valid_loader="_epoch_",
        valid_metric="v_reward",
        minimize_valid_metric=False,
        load_best_on_end=True,
        callbacks=[
            GameCallback(
                env=env,
                replay_buffer=replay_buffer,
Example #13
def train_experiment(device, engine=None):

    with TemporaryDirectory() as logdir:

        # 1. data and transforms

        transforms = Compose([
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.RandomCrop((28, 28)),
            torchvision.transforms.RandomVerticalFlip(),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            Normalize((0.1307, ), (0.3081, )),
        ])

        transform_original = Compose([
            ToTensor(),
            Normalize((0.1307, ), (0.3081, )),
        ])

        mnist = MNIST("./logdir", train=True, download=True, transform=None)
        contrastive_mnist = SelfSupervisedDatasetWrapper(
            mnist,
            transforms=transforms,
            transform_original=transform_original)
        train_loader = torch.utils.data.DataLoader(contrastive_mnist,
                                                   batch_size=BATCH_SIZE)

        mnist_valid = MNIST("./logdir",
                            train=False,
                            download=True,
                            transform=None)
        contrastive_valid = SelfSupervisedDatasetWrapper(
            mnist_valid,
            transforms=transforms,
            transform_original=transform_original)
        valid_loader = torch.utils.data.DataLoader(contrastive_valid,
                                                   batch_size=BATCH_SIZE)

        # 2. model and optimizer
        encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 16),
                                nn.LeakyReLU(inplace=True))
        projection_head = nn.Sequential(
            nn.Linear(16, 16, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(16, 16, bias=True),
        )

        class ContrastiveModel(torch.nn.Module):
            def __init__(self, model, encoder):
                super(ContrastiveModel, self).__init__()
                self.model = model
                self.encoder = encoder

            def forward(self, x):
                emb = self.encoder(x)
                projection = self.model(emb)
                return emb, projection

        model = ContrastiveModel(model=projection_head, encoder=encoder)

        optimizer = Adam(model.parameters(), lr=LR)

        # 3. criterion with triplets sampling
        criterion = NTXentLoss(tau=0.1)

        callbacks = [
            dl.ControlFlowCallback(
                dl.CriterionCallback(input_key="projection_left",
                                     target_key="projection_right",
                                     metric_key="loss"),
                loaders="train",
            ),
            dl.SklearnModelCallback(
                feature_key="embedding_left",
                target_key="target",
                train_loader="train",
                valid_loaders="valid",
                model_fn=RandomForestClassifier,
                predict_method="predict_proba",
                predict_key="sklearn_predict",
                random_state=RANDOM_STATE,
                n_estimators=50,
            ),
            dl.ControlFlowCallback(
                dl.AccuracyCallback(target_key="target",
                                    input_key="sklearn_predict",
                                    topk_args=(1, 3)),
                loaders="valid",
            ),
        ]

        runner = dl.SelfSupervisedRunner()

        logdir = "./logdir"
        runner.train(
            model=model,
            engine=engine or dl.DeviceEngine(device),
            criterion=criterion,
            optimizer=optimizer,
            callbacks=callbacks,
            loaders={
                "train": train_loader,
                "valid": valid_loader
            },
            verbose=False,
            logdir=logdir,
            valid_loader="train",
            valid_metric="loss",
            minimize_valid_metric=True,
            num_epochs=TRAIN_EPOCH,
        )

        valid_path = Path(logdir) / "logs/valid.csv"
        best_accuracy = max(
            float(row["accuracy"]) for row in read_csv(valid_path)
            if row["accuracy"] != "accuracy")

        assert best_accuracy > 0.6
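
This example relies on module-level constants that are not shown; plausible placeholder values, purely assumptions to make the snippet self-contained:

# assumed hyperparameters, not part of the original example
BATCH_SIZE = 1024
TRAIN_EPOCH = 2
LR = 0.01
RANDOM_STATE = 42
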
Example #14
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        # latent_dim = 128
        # generator = nn.Sequential(
        #     # We want to generate 128 coefficients to reshape into a 7x7x128 map
        #     nn.Linear(128, 128 * 7 * 7),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
        #     nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Conv2d(128, 1, (7, 7), padding=3),
        #     nn.Sigmoid(),
        # )
        # discriminator = nn.Sequential(
        #     nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     GlobalMaxPool2d(),
        #     Flatten(),
        #     nn.Linear(128, 1),
        # )
        latent_dim = 32
        generator = nn.Sequential(
            nn.Linear(latent_dim, 28 * 28),
            Lambda(_ddp_hack),
            nn.Sigmoid(),
        )
        discriminator = nn.Sequential(Flatten(), nn.Linear(28 * 28, 1))

        model = {"generator": generator, "discriminator": discriminator}
        criterion = {
            "generator": nn.BCEWithLogitsLoss(),
            "discriminator": nn.BCEWithLogitsLoss()
        }
        optimizer = {
            "generator":
            torch.optim.Adam(generator.parameters(),
                             lr=0.0003,
                             betas=(0.5, 0.999)),
            "discriminator":
            torch.optim.Adam(discriminator.parameters(),
                             lr=0.0003,
                             betas=(0.5, 0.999)),
        }
        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        runner = CustomRunner(latent_dim)
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.CriterionCallback(
                    input_key="combined_predictions",
                    target_key="labels",
                    metric_key="loss_discriminator",
                    criterion_key="discriminator",
                ),
                dl.CriterionCallback(
                    input_key="generated_predictions",
                    target_key="misleading_labels",
                    metric_key="loss_generator",
                    criterion_key="generator",
                ),
                dl.OptimizerCallback(
                    model_key="generator",
                    optimizer_key="generator",
                    metric_key="loss_generator",
                ),
                dl.OptimizerCallback(
                    model_key="discriminator",
                    optimizer_key="discriminator",
                    metric_key="loss_discriminator",
                ),
            ],
            valid_loader="train",
            valid_metric="loss_generator",
            minimize_valid_metric=True,
            num_epochs=1,
            verbose=False,
            logdir=logdir,
        )
        if not isinstance(engine, dl.DistributedDataParallelEngine):
            runner.predict_batch(None)[0, 0].cpu().numpy()
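
Two names are assumed from elsewhere in the source: CustomRunner(latent_dim) matches the GAN runner of Example #18 below (which captures latent_dim by closure rather than a constructor argument), and _ddp_hack is presumably a reshape helper along these lines:

# hypothetical definition of the reshape helper used in the generator
def _ddp_hack(x):
    return x.view(x.size(0), 1, 28, 28)
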
Example #15
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
            "valid":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        runner = dl.SupervisedRunner(input_key="features",
                                     output_key="logits",
                                     target_key="targets",
                                     loss_key="loss")
        callbacks = [
            dl.AccuracyCallback(input_key="logits",
                                target_key="targets",
                                topk_args=(1, 3, 5)),
            dl.PrecisionRecallF1SupportCallback(input_key="logits",
                                                target_key="targets",
                                                num_classes=10),
        ]
        if SETTINGS.ml_required:
            callbacks.append(
                dl.ConfusionMatrixCallback(input_key="logits",
                                           target_key="targets",
                                           num_classes=10))
        if SETTINGS.amp_required and (engine is None or not isinstance(
                engine,
            (dl.AMPEngine, dl.DataParallelAMPEngine,
             dl.DistributedDataParallelAMPEngine),
        )):
            callbacks.append(
                dl.AUCCallback(input_key="logits", target_key="targets"))
        if SETTINGS.onnx_required:
            callbacks.append(
                dl.OnnxCallback(logdir=logdir, input_key="features"))
        if SETTINGS.pruning_required:
            callbacks.append(
                dl.PruningCallback(pruning_fn="l1_unstructured", amount=0.5))
        if SETTINGS.quantization_required:
            callbacks.append(dl.QuantizationCallback(logdir=logdir))
        if engine is None or not isinstance(engine,
                                            dl.DistributedDataParallelEngine):
            callbacks.append(
                dl.TracingCallback(logdir=logdir, input_key="features"))
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            callbacks=callbacks,
            logdir=logdir,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            verbose=False,
            load_best_on_end=True,
            timeit=False,
            check=False,
            overfit=False,
            fp16=False,
            ddp=False,
        )
        # model inference
        for prediction in runner.predict_loader(loader=loaders["valid"]):
            assert prediction["logits"].detach().cpu().numpy().shape[-1] == 10
        # model post-processing
        features_batch = next(iter(loaders["valid"]))[0]
        # model stochastic weight averaging
        model.load_state_dict(
            utils.get_averaged_weights_by_path_mask(logdir=logdir,
                                                    path_mask="*.pth"))
        # model onnx export
        if SETTINGS.onnx_required:
            utils.onnx_export(
                model=runner.model,
                batch=runner.engine.sync_device(features_batch),
                file="./mnist.onnx",
                verbose=False,
            )
        # model quantization
        if SETTINGS.quantization_required:
            utils.quantize_model(model=runner.model)
        # model pruning
        if SETTINGS.pruning_required:
            utils.prune_model(model=runner.model,
                              pruning_fn="l1_unstructured",
                              amount=0.8)
        # model tracing
        utils.trace_model(model=runner.model, batch=features_batch)
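
Assuming utils.trace_model returns the traced TorchScript module, the final call above can also capture and persist it with vanilla torch.jit; the file name is an assumption:

# hedged variant of the tracing step: keep and save the TorchScript module
traced = utils.trace_model(model=runner.model, batch=features_batch)
torch.jit.save(traced, "./mnist.pt")
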
Example #16
def get_engine(self):
    return dl.DeviceEngine(self._device)
Example #17
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:

        # 1. train and valid loaders
        transforms = Compose([ToTensor(), Normalize((0.1307, ), (0.3081, ))])

        train_dataset = datasets.MnistMLDataset(root=os.getcwd(),
                                                download=True,
                                                transform=transforms)
        sampler = data.BatchBalanceClassSampler(
            labels=train_dataset.get_labels(),
            num_classes=5,
            num_samples=10,
            num_batches=10)
        train_loader = DataLoader(dataset=train_dataset, batch_sampler=sampler)

        valid_dataset = datasets.MnistQGDataset(root=os.getcwd(),
                                                transform=transforms,
                                                gallery_fraq=0.2)
        valid_loader = DataLoader(dataset=valid_dataset, batch_size=1024)

        # 2. model and optimizer
        model = models.MnistSimpleNet(out_features=16)
        optimizer = Adam(model.parameters(), lr=0.001)

        # 3. criterion with triplets sampling
        sampler_inbatch = data.HardTripletsSampler(norm_required=False)
        criterion = nn.TripletMarginLossWithSampler(
            margin=0.5, sampler_inbatch=sampler_inbatch)

        # 4. training with catalyst Runner
        callbacks = [
            dl.ControlFlowCallback(
                dl.CriterionCallback(input_key="embeddings",
                                     target_key="targets",
                                     metric_key="loss"),
                loaders="train",
            ),
            dl.ControlFlowCallback(
                dl.CMCScoreCallback(
                    embeddings_key="embeddings",
                    labels_key="targets",
                    is_query_key="is_query",
                    topk_args=[1],
                ),
                loaders="valid",
            ),
            dl.PeriodicLoaderCallback(valid_loader_key="valid",
                                      valid_metric_key="cmc01",
                                      minimize=False,
                                      valid=2),
        ]

        runner = CustomRunner(input_key="features", output_key="embeddings")
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            callbacks=callbacks,
            loaders={
                "train": train_loader,
                "valid": valid_loader
            },
            verbose=False,
            logdir=logdir,
            valid_loader="valid",
            valid_metric="cmc01",
            minimize_valid_metric=False,
            num_epochs=2,
        )
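
Once trained, the embedding model supports plain nearest-gallery retrieval; query_images and gallery_images below are assumed tensors of MNIST digits, not part of the original example:

# illustration only: L2 nearest-neighbour retrieval in embedding space
with torch.no_grad():
    query_emb = model(query_images)         # (num_queries, 16)
    gallery_emb = model(gallery_images)     # (num_gallery, 16)
dist = torch.cdist(query_emb, gallery_emb)  # pairwise L2 distances
nearest = dist.argmin(dim=1)                # closest gallery item per query
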
Example #18
def train_experiment(device):
    with TemporaryDirectory() as logdir:
        latent_dim = 128
        generator = nn.Sequential(
            # We want to generate 128 coefficients to reshape into a 7x7x128 map
            nn.Linear(128, 128 * 7 * 7),
            nn.LeakyReLU(0.2, inplace=True),
            Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
            nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, (7, 7), padding=3),
            nn.Sigmoid(),
        )
        discriminator = nn.Sequential(
            nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            GlobalMaxPool2d(),
            Flatten(),
            nn.Linear(128, 1),
        )

        model = {"generator": generator, "discriminator": discriminator}
        criterion = {
            "generator": nn.BCEWithLogitsLoss(),
            "discriminator": nn.BCEWithLogitsLoss()
        }
        optimizer = {
            "generator":
            torch.optim.Adam(generator.parameters(),
                             lr=0.0003,
                             betas=(0.5, 0.999)),
            "discriminator":
            torch.optim.Adam(discriminator.parameters(),
                             lr=0.0003,
                             betas=(0.5, 0.999)),
        }
        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        class CustomRunner(dl.Runner):
            def predict_batch(self, batch):
                batch_size = 1
                # Sample random points in the latent space
                random_latent_vectors = torch.randn(batch_size,
                                                    latent_dim).to(self.device)
                # Decode them to fake images
                generated_images = self.model["generator"](
                    random_latent_vectors).detach()
                return generated_images

            def handle_batch(self, batch):
                real_images, _ = batch
                batch_size = real_images.shape[0]

                # Sample random points in the latent space
                random_latent_vectors = torch.randn(batch_size,
                                                    latent_dim).to(self.device)

                # Decode them to fake images
                generated_images = self.model["generator"](
                    random_latent_vectors).detach()
                # Combine them with real images
                combined_images = torch.cat([generated_images, real_images])

                # Assemble labels discriminating real from fake images
                labels = torch.cat([
                    torch.ones((batch_size, 1)),
                    torch.zeros((batch_size, 1))
                ]).to(self.device)
                # Add random noise to the labels - important trick!
                labels += 0.05 * torch.rand(labels.shape).to(self.device)

                # Discriminator forward
                combined_predictions = self.model["discriminator"](
                    combined_images)

                # Sample random points in the latent space
                random_latent_vectors = torch.randn(batch_size,
                                                    latent_dim).to(self.device)
                # Assemble labels that say "all real images"
                misleading_labels = torch.zeros(
                    (batch_size, 1)).to(self.device)

                # Generator forward
                generated_images = self.model["generator"](
                    random_latent_vectors)
                generated_predictions = self.model["discriminator"](
                    generated_images)

                self.batch = {
                    "combined_predictions": combined_predictions,
                    "labels": labels,
                    "generated_predictions": generated_predictions,
                    "misleading_labels": misleading_labels,
                }

        runner = CustomRunner()
        runner.train(
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.CriterionCallback(
                    input_key="combined_predictions",
                    target_key="labels",
                    metric_key="loss_discriminator",
                    criterion_key="discriminator",
                ),
                dl.CriterionCallback(
                    input_key="generated_predictions",
                    target_key="misleading_labels",
                    metric_key="loss_generator",
                    criterion_key="generator",
                ),
                dl.OptimizerCallback(
                    model_key="generator",
                    optimizer_key="generator",
                    metric_key="loss_generator",
                ),
                dl.OptimizerCallback(
                    model_key="discriminator",
                    optimizer_key="discriminator",
                    metric_key="loss_discriminator",
                ),
            ],
            valid_loader="train",
            valid_metric="loss_generator",
            minimize_valid_metric=True,
            num_epochs=1,
            verbose=False,
            logdir=logdir,
        )
        runner.predict_batch(None)[0, 0].cpu().numpy()
Example #19
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        from catalyst import utils

        utils.set_global_seed(RANDOM_STATE)
        # 1. train, valid and test loaders
        transforms = Compose([ToTensor(), Normalize((0.1307, ), (0.3081, ))])

        train_data = MNIST(os.getcwd(),
                           train=True,
                           download=True,
                           transform=transforms)
        train_labels = train_data.targets.cpu().numpy().tolist()
        train_sampler = data.BatchBalanceClassSampler(train_labels,
                                                      num_classes=10,
                                                      num_samples=4)
        train_loader = DataLoader(train_data, batch_sampler=train_sampler)

        valid_dataset = MNIST(root=os.getcwd(),
                              transform=transforms,
                              train=False,
                              download=True)
        valid_loader = DataLoader(dataset=valid_dataset, batch_size=32)

        test_dataset = MNIST(root=os.getcwd(),
                             transform=transforms,
                             train=False,
                             download=True)
        test_loader = DataLoader(dataset=test_dataset, batch_size=32)

        # 2. model and optimizer
        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 16),
                              nn.LeakyReLU(inplace=True))
        optimizer = Adam(model.parameters(), lr=LR)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # 3. criterion with triplets sampling
        sampler_inbatch = data.HardTripletsSampler(norm_required=False)
        criterion = nn.TripletMarginLossWithSampler(
            margin=0.5, sampler_inbatch=sampler_inbatch)

        # 4. training with catalyst Runner
        class CustomRunner(dl.SupervisedRunner):
            def handle_batch(self, batch) -> None:
                images, targets = batch["features"].float(
                ), batch["targets"].long()
                features = self.model(images)
                self.batch = {
                    "embeddings": features,
                    "targets": targets,
                }

        callbacks = [
            dl.ControlFlowCallback(
                dl.CriterionCallback(input_key="embeddings",
                                     target_key="targets",
                                     metric_key="loss"),
                loaders="train",
            ),
            dl.SklearnModelCallback(
                feature_key="embeddings",
                target_key="targets",
                train_loader="train",
                valid_loaders=["valid", "infer"],
                model_fn=RandomForestClassifier,
                predict_method="predict_proba",
                predict_key="sklearn_predict",
                random_state=RANDOM_STATE,
                n_estimators=50,
            ),
            dl.ControlFlowCallback(
                dl.AccuracyCallback(target_key="targets",
                                    input_key="sklearn_predict",
                                    topk_args=(1, 3)),
                loaders=["valid", "infer"],
            ),
        ]

        runner = CustomRunner(input_key="features", output_key="embeddings")
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders={
                "train": train_loader,
                "valid": valid_loader,
                "infer": test_loader
            },
            verbose=False,
            valid_loader="valid",
            valid_metric="accuracy",
            minimize_valid_metric=False,
            num_epochs=TRAIN_EPOCH,
            logdir=logdir,
        )

        valid_path = Path(logdir) / "logs/infer.csv"
        best_accuracy = max(
            float(row["accuracy"]) for row in read_csv(valid_path))

        assert best_accuracy > 0.8
Example #20
def train_experiment(device):
    with TemporaryDirectory() as logdir:

        # <--- multi-model setup --->
        encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128))
        head = nn.Linear(128, 10)
        model = {"encoder": encoder, "head": head}
        optimizer = optim.Adam(
            [{"params": encoder.parameters()}, {"params": head.parameters()},], lr=0.02
        )
        # <--- multi-model setup --->
        criterion = nn.CrossEntropyLoss()

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32
            ),
        }

        class CustomRunner(dl.Runner):
            def predict_batch(self, batch):
                # model inference step: self.model is a dict here,
                # so the batch must go through encoder and head explicitly
                x = batch[0].to(self.device)
                return self.model["head"](self.model["encoder"](x))

            def on_loader_start(self, runner):
                super().on_loader_start(runner)
                self.meters = {
                    key: metrics.AdditiveValueMetric(compute_on_call=False)
                    for key in ["loss", "accuracy01", "accuracy03"]
                }

            def handle_batch(self, batch):
                # model train/valid step
                # unpack the batch
                x, y = batch
                # <--- multi-model usage --->
                # run model forward pass
                x_ = self.model["encoder"](x)
                logits = self.model["head"](x_)
                # <--- multi-model usage --->
                # compute the loss
                loss = self.criterion(logits, y)
                # compute other metrics of interest
                accuracy01, accuracy03 = metrics.accuracy(logits, y, topk=(1, 3))
                # log metrics
                self.batch_metrics.update(
                    {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03}
                )
                for key in ["loss", "accuracy01", "accuracy03"]:
                    self.meters[key].update(self.batch_metrics[key].item(), self.batch_size)
                # run model backward pass
                if self.is_train_loader:
                    loss.backward()
                    self.optimizer.step()
                    self.optimizer.zero_grad()

            def on_loader_end(self, runner):
                for key in ["loss", "accuracy01", "accuracy03"]:
                    self.loader_metrics[key] = self.meters[key].compute()[0]
                super().on_loader_end(runner)

        runner = CustomRunner()
        # model training
        runner.train(
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            verbose=True,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
        )
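
With predict_batch defined, single-batch inference on the trained encoder/head pair is a one-liner; a hedged usage sketch:

# hypothetical: classify one validation batch after training
batch = next(iter(loaders["valid"]))
logits = runner.predict_batch(batch)  # shape: (batch_size, 10)
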
Example #21
def main(args):
    train_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=True, transform=ToTensor())
    )
    val_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=False, transform=ToTensor())
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64)
    loaders = {"train": train_dataloader, "valid": val_dataloader}
    utils.set_global_seed(args.seed)
    net = nn.Sequential(
        Flatten(),
        nn.Linear(28 * 28, 300),
        nn.ReLU(),
        nn.Linear(300, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    initial_state_dict = net.state_dict()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    if args.device is not None:
        engine = dl.DeviceEngine(args.device)
    else:
        engine = None
    if args.vanilla_pruning:
        runner = dl.SupervisedRunner(engine=engine)

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            ],
            logdir="./logdir",
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
        pruning_fn = partial(
            utils.pruning.prune_model,
            pruning_fn=args.pruning_method,
            amount=args.amount,
            keys_to_prune=["weights"],
            dim=args.dim,
            l_norm=args.n,
        )
        acc, amount = validate_model(
            runner, pruning_fn=pruning_fn, loader=loaders["valid"], num_sessions=args.num_sessions
        )
        torch.save(acc, "accuracy.pth")
        torch.save(amount, "amount.pth")

    else:
        runner = PruneRunner(num_sessions=args.num_sessions, engine=engine)
        callbacks = [
            dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            dl.PruningCallback(
                args.pruning_method,
                keys_to_prune=["weight"],
                amount=args.amount,
                remove_reparametrization_on_stage_end=False,
            ),
            dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
        ]
        if args.lottery_ticket:
            callbacks.append(LotteryTicketCallback(initial_state_dict=initial_state_dict))
        if args.kd:
            net.load_state_dict(torch.load(args.state_dict))
            callbacks.append(
                PrepareForFinePruningCallback(probability_shift=args.probability_shift)
            )
            callbacks.append(KLDivCallback(temperature=4, student_logits_key="logits"))
            callbacks.append(
                MetricAggregationCallback(
                    prefix="loss", metrics={"loss": 0.1, "kl_div_loss": 0.9}, mode="weighted_sum"
                )
            )

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=callbacks,
            logdir=args.logdir,
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
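
main expects a parsed argument namespace; a plausible argparse setup covering the fields used above, with every default being an assumption:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--logdir", type=str, default="./logdir")
parser.add_argument("--num-epochs", type=int, default=5)
parser.add_argument("--num-sessions", type=int, default=5)
parser.add_argument("--pruning-method", type=str, default="l1_unstructured")
parser.add_argument("--amount", type=float, default=0.5)
parser.add_argument("--dim", type=int, default=None)
parser.add_argument("--n", type=int, default=None)
parser.add_argument("--vanilla-pruning", action="store_true")
parser.add_argument("--lottery-ticket", action="store_true")
parser.add_argument("--kd", action="store_true")
parser.add_argument("--state-dict", type=str, default=None)
parser.add_argument("--probability-shift", action="store_true")
main(parser.parse_args())
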
Example #22
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        # sample data
        num_samples, num_features = int(1e4), int(1e1)
        num_classes1, num_classes2 = 4, 10
        X = torch.rand(num_samples, num_features)
        y1 = (torch.rand(num_samples) * num_classes1).to(torch.int64)
        y2 = (torch.rand(num_samples) * num_classes2).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y1, y2)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = CustomModule(num_features, num_classes1, num_classes2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # model training
        runner = CustomRunner()
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            verbose=False,
            callbacks=[
                dl.CriterionCallback(metric_key="loss1",
                                     input_key="logits1",
                                     target_key="targets1"),
                dl.CriterionCallback(metric_key="loss2",
                                     input_key="logits2",
                                     target_key="targets2"),
                dl.MetricAggregationCallback(metric_key="loss",
                                             metrics=["loss1", "loss2"],
                                             mode="mean"),
                dl.OptimizerCallback(metric_key="loss"),
                dl.SchedulerCallback(),
                dl.AccuracyCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_",
                ),
                dl.AccuracyCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_",
                ),
                dl.ConfusionMatrixCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_cm",
                ),
                # catalyst[ml] required
                dl.ConfusionMatrixCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_cm",
                ),
                # catalyst[ml] required
                dl.CheckpointCallback(
                    "./logs/one",
                    loader_key="valid",
                    metric_key="one_accuracy",
                    minimize=False,
                    save_n_best=1,
                ),
                dl.CheckpointCallback(
                    "./logs/two",
                    loader_key="valid",
                    metric_key="two_accuracy03",
                    minimize=False,
                    save_n_best=3,
                ),
            ],
            loggers={
                "console": dl.ConsoleLogger(),
                "tb": dl.TensorboardLogger("./logs/tb")
            },
        )
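
CustomModule and CustomRunner are defined elsewhere; from the callback keys, the module must expose two classification heads. A hedged sketch of a compatible module (layer sizes are assumptions):

# assumed two-head module matching the "logits1"/"logits2" keys above
import torch
import torch.nn as nn

class CustomModule(nn.Module):
    def __init__(self, in_features, num_classes1, num_classes2, hidden=128):
        super().__init__()
        self.shared = nn.Linear(in_features, hidden)
        self.head1 = nn.Linear(hidden, num_classes1)  # -> "logits1"
        self.head2 = nn.Linear(hidden, num_classes2)  # -> "logits2"

    def forward(self, x):
        z = torch.relu(self.shared(x))
        return self.head1(z), self.head2(z)
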
Example #23
def train_experiment(device):
    with TemporaryDirectory() as logdir:
        teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = {
            "cls": nn.CrossEntropyLoss(),
            "kl": nn.KLDivLoss(reduction="batchmean")
        }
        optimizer = optim.Adam(student.parameters(), lr=0.02)

        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=True,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
            "valid":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        class DistilRunner(dl.Runner):
            def handle_batch(self, batch):
                x, y = batch

                teacher.eval()  # let's manually set teacher model to eval mode
                with torch.no_grad():
                    t_logits = self.model["teacher"](x)

                s_logits = self.model["student"](x)
                self.batch = {
                    "t_logits": t_logits,
                    "s_logits": s_logits,
                    "targets": y,
                    "s_logprobs": F.log_softmax(s_logits, dim=-1),
                    "t_probs": F.softmax(t_logits, dim=-1),
                }

        runner = DistilRunner()
        # model training
        runner.train(
            engine=dl.DeviceEngine(device),
            model={
                "teacher": teacher,
                "student": student
            },
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=True,
            callbacks=[
                dl.AccuracyCallback(input_key="t_logits",
                                    target_key="targets",
                                    num_classes=2,
                                    prefix="teacher_"),
                dl.AccuracyCallback(input_key="s_logits",
                                    target_key="targets",
                                    num_classes=2,
                                    prefix="student_"),
                dl.CriterionCallback(
                    input_key="s_logits",
                    target_key="targets",
                    metric_key="cls_loss",
                    criterion_key="cls",
                ),
                dl.CriterionCallback(
                    input_key="s_logprobs",
                    target_key="t_probs",
                    metric_key="kl_div_loss",
                    criterion_key="kl",
                ),
                dl.MetricAggregationCallback(
                    prefix="loss",
                    metrics=["kl_div_loss", "cls_loss"],
                    mode="mean"),
                dl.OptimizerCallback(metric_key="loss", model_key="student"),
                dl.CheckpointCallback(
                    logdir=logdir,
                    loader_key="valid",
                    metric_key="loss",
                    minimize=True,
                    save_n_best=3,
                ),
            ],
        )
Example #24
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        from catalyst import utils

        utils.set_global_seed(RANDOM_STATE)
        # 1. generate data
        num_samples, num_features, num_classes = int(1e4), 30, 3
        X, y = make_classification(
            n_samples=num_samples,
            n_features=num_features,
            n_informative=num_features,
            n_repeated=0,
            n_redundant=0,
            n_classes=num_classes,
            n_clusters_per_class=1,
        )
        X, y = torch.tensor(X), torch.tensor(y)
        dataset = TensorDataset(X, y)
        loader = DataLoader(dataset,
                            batch_size=64,
                            num_workers=1,
                            shuffle=True)

        # 2. model, optimizer and scheduler
        hidden_size, out_features = 20, 16
        model = nn.Sequential(nn.Linear(num_features, hidden_size), nn.ReLU(),
                              nn.Linear(hidden_size, out_features))
        optimizer = Adam(model.parameters(), lr=LR)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # 3. criterion with triplets sampling
        sampler_inbatch = data.HardTripletsSampler(norm_required=False)
        criterion = nn.TripletMarginLossWithSampler(
            margin=0.5, sampler_inbatch=sampler_inbatch)

        # 4. training with catalyst Runner
        class CustomRunner(dl.SupervisedRunner):
            def handle_batch(self, batch) -> None:
                features, targets = batch["features"].float(
                ), batch["targets"].long()
                embeddings = self.model(features)
                self.batch = {
                    "embeddings": embeddings,
                    "targets": targets,
                }

        callbacks = [
            dl.SklearnModelCallback(
                feature_key="embeddings",
                target_key="targets",
                train_loader="train",
                valid_loaders="valid",
                model_fn=RandomForestClassifier,
                predict_method="predict_proba",
                predict_key="sklearn_predict",
                random_state=RANDOM_STATE,
                n_estimators=100,
            ),
            dl.ControlFlowCallback(
                dl.AccuracyCallback(target_key="targets",
                                    input_key="sklearn_predict",
                                    topk_args=(1, 3)),
                loaders="valid",
            ),
        ]

        runner = CustomRunner(input_key="features", output_key="embeddings")
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            callbacks=callbacks,
            scheduler=scheduler,
            loaders={
                "train": loader,
                "valid": loader
            },
            verbose=False,
            valid_loader="valid",
            valid_metric="accuracy",
            minimize_valid_metric=False,
            num_epochs=TRAIN_EPOCH,
            logdir=logdir,
        )

        valid_path = Path(logdir) / "logs/valid.csv"
        best_accuracy = max(
            float(row["accuracy"]) for row in read_csv(valid_path))

        assert best_accuracy > 0.9