Example #1
def train_experiment(device):
    with TemporaryDirectory() as logdir:

        # <--- multi-model/optimizer setup --->
        encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128))
        head = nn.Linear(128, 10)
        model = {"encoder": encoder, "head": head}
        optimizer = {
            "encoder": optim.Adam(encoder.parameters(), lr=0.02),
            "head": optim.Adam(head.parameters(), lr=0.001),
        }
        # <--- multi-model/optimizer setup --->
        criterion = nn.CrossEntropyLoss()

        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32
            ),
        }

        class CustomRunner(dl.Runner):
            def predict_batch(self, batch):
                # model inference step: self.model is a dict of modules here,
                # so run the encoder and head explicitly
                x = batch[0].to(self.device)
                return self.model["head"](self.model["encoder"](x))

            def on_loader_start(self, runner):
                super().on_loader_start(runner)
                self.meters = {
                    key: metrics.AdditiveValueMetric(compute_on_call=False)
                    for key in ["loss", "accuracy01", "accuracy03"]
                }

            def handle_batch(self, batch):
                # model train/valid step
                # unpack the batch
                x, y = batch
                # <--- multi-model usage --->
                # run model forward pass
                x_ = self.model["encoder"](x)
                logits = self.model["head"](x_)
                # <--- multi-model usage --->
                # compute the loss
                loss = self.criterion(logits, y)
                # compute other metrics of interest
                accuracy01, accuracy03 = metrics.accuracy(logits, y, topk=(1, 3))
                # log metrics
                self.batch_metrics.update(
                    {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03}
                )
                for key in ["loss", "accuracy01", "accuracy03"]:
                    self.meters[key].update(self.batch_metrics[key].item(), self.batch_size)
                # run model backward pass
                if self.is_train_loader:
                    loss.backward()
                    # <--- multi-model/optimizer usage --->
                    self.optimizer["encoder"].step()
                    self.optimizer["head"].step()
                    self.optimizer["encoder"].zero_grad()
                    self.optimizer["head"].zero_grad()
                    # <--- multi-model/optimizer usage --->

            def on_loader_end(self, runner):
                for key in ["loss", "accuracy01", "accuracy03"]:
                    self.loader_metrics[key] = self.meters[key].compute()[0]
                super().on_loader_end(runner)

        runner = CustomRunner()
        # model training
        runner.train(
            engine=dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            verbose=True,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
        )
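
A minimal invocation sketch (not part of the original snippet, and assuming the names used inside the function - os, nn, optim, DataLoader, MNIST, ToTensor, dl, metrics, TemporaryDirectory - are imported at module level):

train_experiment("cpu")       # single-process CPU run
# train_experiment("cuda:0")  # or on the first GPU, if one is available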
Example #2
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
            "valid":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        runner = dl.SupervisedRunner(input_key="features",
                                     output_key="logits",
                                     target_key="targets",
                                     loss_key="loss")
        callbacks = [
            dl.AccuracyCallback(input_key="logits",
                                target_key="targets",
                                topk_args=(1, 3, 5)),
            dl.PrecisionRecallF1SupportCallback(input_key="logits",
                                                target_key="targets",
                                                num_classes=10),
            dl.ConfusionMatrixCallback(input_key="logits",
                                       target_key="targets",
                                       num_classes=10),
        ]
        if engine is None or not isinstance(
                engine, (dl.AMPEngine, dl.DataParallelAMPEngine,
                         dl.DistributedDataParallelAMPEngine)):
            callbacks.append(
                dl.AUCCallback(input_key="logits", target_key="targets"))
        if SETTINGS.onnx_required:
            callbacks.append(
                dl.OnnxCallback(logdir=logdir, input_key="features"))
        if SETTINGS.pruning_required:
            callbacks.append(
                dl.PruningCallback(pruning_fn="l1_unstructured", amount=0.5))
        if SETTINGS.quantization_required:
            callbacks.append(dl.QuantizationCallback(logdir=logdir))
        if engine is None or not isinstance(engine,
                                            dl.DistributedDataParallelEngine):
            callbacks.append(
                dl.TracingCallback(logdir=logdir, input_key="features"))
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            callbacks=callbacks,
            logdir=logdir,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            verbose=False,
            load_best_on_end=True,
            timeit=False,
            check=False,
            overfit=False,
            fp16=False,
            ddp=False,
        )
        # model inference
        for prediction in runner.predict_loader(loader=loaders["valid"]):
            assert prediction["logits"].detach().cpu().numpy().shape[-1] == 10
        # model post-processing
        features_batch = next(iter(loaders["valid"]))[0]
        # model stochastic weight averaging
        model.load_state_dict(
            utils.get_averaged_weights_by_path_mask(logdir=logdir,
                                                    path_mask="*.pth"))
        # model onnx export
        if SETTINGS.onnx_required:
            utils.onnx_export(
                model=runner.model,
                batch=runner.engine.sync_device(features_batch),
                file="./mnist.onnx",
                verbose=False,
            )
        # model quantization
        if SETTINGS.quantization_required:
            utils.quantize_model(model=runner.model)
        # model pruning
        if SETTINGS.pruning_required:
            utils.prune_model(model=runner.model,
                              pruning_fn="l1_unstructured",
                              amount=0.8)
        # model tracing
        utils.trace_model(model=runner.model, batch=features_batch)
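
A small follow-up sketch (an assumption, not in the original): catalyst's utils.trace_model returns a TorchScript module, so the result of the last call could be persisted and reloaded with the standard torch.jit API (assuming torch is imported in this module):

traced = utils.trace_model(model=runner.model, batch=features_batch)
torch.jit.save(traced, "./mnist_traced.pt")
reloaded = torch.jit.load("./mnist_traced.pt")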
Example #3
# The fastai DataLoader is a drop-in replacement for Pytorch's;
#   no code changes are required other than changing the import line
from fastai.data.load import DataLoader
import os, torch
from torch.nn import functional as F
from catalyst import dl
from catalyst.data.cv import ToTensor
from catalyst.contrib.datasets import MNIST
from catalyst.utils import metrics

model = torch.nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

loaders = {
    "train": DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32),
    "valid": DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32),
}

class CustomRunner(dl.Runner):
    def predict_batch(self, batch): return self.model(batch[0].to(self.device).view(batch[0].size(0), -1))

    def _handle_batch(self, batch):
        x, y = batch
        y_hat = self.model(x.view(x.size(0), -1))

        loss = F.cross_entropy(y_hat, y)
        accuracy01, accuracy03 = metrics.accuracy(y_hat, y, topk=(1, 3))
        self.batch_metrics.update(
            {"loss": loss, "accuracy01": accuracy01, "accuracy03": accuracy03}
        )
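        # Hedged completion (not part of the original excerpt): the matching
        # Catalyst MNIST example typically finishes the step with a backward
        # pass, assuming an optimizer is passed to runner.train():
        if self.is_train_loader:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()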
Example #4
def main(args):
    train_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=True, transform=ToTensor())
    )
    val_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=False, transform=ToTensor())
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64)
    loaders = {"train": train_dataloader, "valid": val_dataloader}
    utils.set_global_seed(args.seed)
    net = nn.Sequential(
        Flatten(),
        nn.Linear(28 * 28, 300),
        nn.ReLU(),
        nn.Linear(300, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    initial_state_dict = net.state_dict()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    if args.device is not None:
        engine = dl.DeviceEngine(args.device)
    else:
        engine = None
    if args.vanilla_pruning:
        runner = dl.SupervisedRunner(engine=engine)

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            ],
            logdir="./logdir",
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
        pruning_fn = partial(
            utils.pruning.prune_model,
            pruning_fn=args.pruning_method,
            amount=args.amount,
            keys_to_prune=["weights"],
            dim=args.dim,
            l_norm=args.n,
        )
        acc, amount = validate_model(
            runner, pruning_fn=pruning_fn, loader=loaders["valid"], num_sessions=args.num_sessions
        )
        torch.save(acc, "accuracy.pth")
        torch.save(amount, "amount.pth")

    else:
        runner = PruneRunner(num_sessions=args.num_sessions, engine=engine)
        callbacks = [
            dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            dl.PruningCallback(
                args.pruning_method,
                keys_to_prune=["weight"],
                amount=args.amount,
                remove_reparametrization_on_stage_end=False,
            ),
            dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
        ]
        if args.lottery_ticket:
            callbacks.append(LotteryTicketCallback(initial_state_dict=initial_state_dict))
        if args.kd:
            net.load_state_dict(torch.load(args.state_dict))
            callbacks.append(
                PrepareForFinePruningCallback(probability_shift=args.probability_shift)
            )
            callbacks.append(KLDivCallback(temperature=4, student_logits_key="logits"))
            callbacks.append(
                MetricAggregationCallback(
                    prefix="loss", metrics={"loss": 0.1, "kl_div_loss": 0.9}, mode="weighted_sum"
                )
            )

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=callbacks,
            logdir=args.logdir,
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
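
The snippet assumes an argparse-style namespace; a hypothetical parser covering the attributes read above (names and defaults are illustrative, not taken from the original script):

import argparse

def build_parser():
    parser = argparse.ArgumentParser("MNIST pruning example (hypothetical arguments)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--logdir", type=str, default="./logdir")
    parser.add_argument("--num-epochs", type=int, default=1)
    parser.add_argument("--num-sessions", type=int, default=5)
    parser.add_argument("--pruning-method", type=str, default="l1_unstructured")
    parser.add_argument("--amount", type=float, default=0.5)
    parser.add_argument("--dim", type=int, default=None)
    parser.add_argument("--n", type=int, default=None)
    parser.add_argument("--vanilla-pruning", action="store_true")
    parser.add_argument("--lottery-ticket", action="store_true")
    parser.add_argument("--kd", action="store_true")
    parser.add_argument("--state-dict", type=str, default=None)
    parser.add_argument("--probability-shift", action="store_true")
    return parser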
def main():
    generator = nn.Sequential(
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        nn.Linear(128, 128 * 7 * 7),
        nn.LeakyReLU(0.2, inplace=True),
        Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
        nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(128, 1, (7, 7), padding=3),
        nn.Sigmoid(),
    )
    discriminator = nn.Sequential(
        nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
        nn.LeakyReLU(0.2, inplace=True),
        GlobalMaxPool2d(),
        Flatten(),
        nn.Linear(128, 1),
    )

    model = {"generator": generator, "discriminator": discriminator}
    optimizer = {
        "generator":
        torch.optim.Adam(generator.parameters(), lr=0.0003,
                         betas=(0.5, 0.999)),
        "discriminator":
        torch.optim.Adam(discriminator.parameters(),
                         lr=0.0003,
                         betas=(0.5, 0.999)),
    }
    loaders = {
        "train":
        DataLoader(
            MNIST(
                "./data",
                train=True,
                download=True,
                transform=ToTensor(),
            ),
            batch_size=32,
        ),
    }

    runner = CustomRunner()
    runner.train(
        model=model,
        optimizer=optimizer,
        loaders=loaders,
        callbacks=[
            dl.OptimizerCallback(optimizer_key="generator",
                                 metric_key="loss_generator"),
            dl.OptimizerCallback(optimizer_key="discriminator",
                                 metric_key="loss_discriminator"),
        ],
        main_metric="loss_generator",
        num_epochs=20,
        verbose=True,
        logdir="./logs_gan",
        check=True,
    )
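
The GAN snippets above use Lambda, Flatten, and GlobalMaxPool2d without showing their imports; for the Catalyst versions these examples target, a plausible import (an assumption about the exact path) is:

from catalyst.contrib.nn import Flatten, GlobalMaxPool2d, Lambda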
import collections
import torch

from catalyst.contrib.datasets import MNIST
from catalyst.data.cv import ToTensor, Compose, Normalize

bs = 32
num_workers = 0

data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

loaders = collections.OrderedDict()

trainset = MNIST(
    "./data", train=False, download=True, transform=data_transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=bs, shuffle=True, num_workers=num_workers
)

testset = MNIST("./data", train=False, download=True, transform=data_transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=bs, shuffle=False, num_workers=num_workers
)

loaders["train"] = trainloader
loaders["valid"] = testloader

# # Model
Example #7
    nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
    nn.LeakyReLU(0.2, inplace=True), GlobalMaxPool2d(), Flatten(),
    nn.Linear(128, 1))

model = {"generator": generator, "discriminator": discriminator}
optimizer = {
    "generator":
    torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.5, 0.999)),
    "discriminator":
    torch.optim.Adam(discriminator.parameters(), lr=0.0003,
                     betas=(0.5, 0.999)),
}
loaders = {
    "train":
    DataLoader(MNIST(os.getcwd(),
                     train=True,
                     download=True,
                     transform=ToTensor()),
               batch_size=32),
}


class CustomRunner(dl.Runner):
    def _handle_batch(self, batch):
        real_images, _ = batch
        batch_metrics = {}

        # Sample random points in the latent space
        batch_size = real_images.shape[0]
        random_latent_vectors = torch.randn(batch_size,
                                            latent_dim).to(self.device)
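        # Hedged continuation (not part of this excerpt): the classic Catalyst
        # GAN example typically proceeds roughly as below; the loss keys match
        # the loss_generator / loss_discriminator metrics used with
        # dl.OptimizerCallback in the companion setup above.
        # Decode the latent vectors into fake images
        generated_images = self.model["generator"](random_latent_vectors)
        # Combine them with real images and label generated=1, real=0
        combined_images = torch.cat([generated_images, real_images])
        labels = torch.cat(
            [torch.ones((batch_size, 1)), torch.zeros((batch_size, 1))]
        ).to(self.device)
        # A small amount of label noise is a common stabilization trick
        labels += 0.05 * torch.rand(labels.shape).to(self.device)

        # Discriminator loss on the real/fake mix
        predictions = self.model["discriminator"](combined_images)
        batch_metrics["loss_discriminator"] = \
            torch.nn.functional.binary_cross_entropy_with_logits(predictions, labels)

        # Generator loss: fresh latent vectors with "misleading" (all-real) labels
        random_latent_vectors = torch.randn(batch_size, latent_dim).to(self.device)
        misleading_labels = torch.zeros((batch_size, 1)).to(self.device)
        predictions = self.model["discriminator"](self.model["generator"](random_latent_vectors))
        batch_metrics["loss_generator"] = \
            torch.nn.functional.binary_cross_entropy_with_logits(predictions, misleading_labels)

        self.batch_metrics.update(**batch_metrics)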
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        model = {"teacher": teacher, "student": student}
        criterion = {
            "cls": nn.CrossEntropyLoss(),
            "kl": nn.KLDivLoss(reduction="batchmean")
        }
        optimizer = optim.Adam(student.parameters(), lr=0.02)

        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=True,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
            "valid":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        runner = DistilRunner()
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=False,
            callbacks=[
                dl.AccuracyCallback(input_key="t_logits",
                                    target_key="targets",
                                    num_classes=2,
                                    prefix="teacher_"),
                dl.AccuracyCallback(input_key="s_logits",
                                    target_key="targets",
                                    num_classes=2,
                                    prefix="student_"),
                dl.CriterionCallback(
                    input_key="s_logits",
                    target_key="targets",
                    metric_key="cls_loss",
                    criterion_key="cls",
                ),
                dl.CriterionCallback(
                    input_key="s_logprobs",
                    target_key="t_probs",
                    metric_key="kl_div_loss",
                    criterion_key="kl",
                ),
                dl.MetricAggregationCallback(
                    metric_key="loss",
                    metrics=["kl_div_loss", "cls_loss"],
                    mode="mean"),
                dl.OptimizerCallback(metric_key="loss", model_key="student"),
                dl.CheckpointCallback(
                    logdir=logdir,
                    loader_key="valid",
                    metric_key="loss",
                    minimize=True,
                    save_n_best=3,
                ),
            ],
        )
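
DistilRunner is referenced above but not shown in this excerpt; a hypothetical handle_batch consistent with the keys the callbacks expect (t_logits, s_logits, t_probs, s_logprobs, targets) might look like this (in practice the teacher forward pass is usually run under no_grad):

class DistilRunner(dl.Runner):
    def handle_batch(self, batch):
        x, y = batch
        t_logits = self.model["teacher"](x)  # teacher predictions
        s_logits = self.model["student"](x)  # student predictions
        self.batch = {
            "t_logits": t_logits,
            "s_logits": s_logits,
            # soft targets / log-probabilities for the KL-divergence term
            "t_probs": t_logits.softmax(-1),
            "s_logprobs": s_logits.log_softmax(-1),
            "targets": y,
        }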
Example #9
def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        # latent_dim = 128
        # generator = nn.Sequential(
        #     # We want to generate 128 coefficients to reshape into a 7x7x128 map
        #     nn.Linear(128, 128 * 7 * 7),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     Lambda(lambda x: x.view(x.size(0), 128, 7, 7)),
        #     nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.ConvTranspose2d(128, 128, (4, 4), stride=(2, 2), padding=1),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Conv2d(128, 1, (7, 7), padding=3),
        #     nn.Sigmoid(),
        # )
        # discriminator = nn.Sequential(
        #     nn.Conv2d(1, 64, (3, 3), stride=(2, 2), padding=1),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=1),
        #     nn.LeakyReLU(0.2, inplace=True),
        #     GlobalMaxPool2d(),
        #     Flatten(),
        #     nn.Linear(128, 1),
        # )
        latent_dim = 32
        generator = nn.Sequential(
            nn.Linear(latent_dim, 28 * 28),
            Lambda(_ddp_hack),
            nn.Sigmoid(),
        )
        discriminator = nn.Sequential(Flatten(), nn.Linear(28 * 28, 1))

        model = {"generator": generator, "discriminator": discriminator}
        criterion = {
            "generator": nn.BCEWithLogitsLoss(),
            "discriminator": nn.BCEWithLogitsLoss()
        }
        optimizer = {
            "generator":
            torch.optim.Adam(generator.parameters(),
                             lr=0.0003,
                             betas=(0.5, 0.999)),
            "discriminator":
            torch.optim.Adam(discriminator.parameters(),
                             lr=0.0003,
                             betas=(0.5, 0.999)),
        }
        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        runner = CustomRunner(latent_dim)
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.CriterionCallback(
                    input_key="combined_predictions",
                    target_key="labels",
                    metric_key="loss_discriminator",
                    criterion_key="discriminator",
                ),
                dl.CriterionCallback(
                    input_key="generated_predictions",
                    target_key="misleading_labels",
                    metric_key="loss_generator",
                    criterion_key="generator",
                ),
                dl.OptimizerCallback(
                    model_key="generator",
                    optimizer_key="generator",
                    metric_key="loss_generator",
                ),
                dl.OptimizerCallback(
                    model_key="discriminator",
                    optimizer_key="discriminator",
                    metric_key="loss_discriminator",
                ),
            ],
            valid_loader="train",
            valid_metric="loss_generator",
            minimize_valid_metric=True,
            num_epochs=1,
            verbose=False,
            logdir=logdir,
        )
        if not isinstance(engine, dl.DistributedDataParallelEngine):
            runner.predict_batch(None)[0, 0].cpu().numpy()
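
Lambda(_ddp_hack) above references a helper that is not shown in this excerpt; a plausible definition (an assumption based on how the generator output is consumed by the discriminator and by predict_batch) simply reshapes the flat 28 * 28 vector into MNIST image form:

def _ddp_hack(x):
    # hypothetical helper: reshape (batch, 28 * 28) -> (batch, 1, 28, 28)
    return x.view(x.size(0), 1, 28, 28)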
Example #10
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from catalyst import dl, metrics
from catalyst.contrib.data.cv import ToTensor
from catalyst.contrib.datasets import MNIST

model = torch.nn.Linear(28 * 28, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

loaders = {
    "train":
    DataLoader(
        MNIST("./data", train=True, download=True, transform=ToTensor()),
        batch_size=32,
    ),
}


class CustomRunner(dl.Runner):
    def predict_batch(self, batch):
        # model inference step
        return self.model(batch[0].to(self.device).view(batch[0].size(0), -1))

    def _handle_batch(self, batch):
        # model train/valid step
        x, y = batch
        y_hat = self.model(x.view(x.size(0), -1))
from itertools import chain

from catalyst.callbacks import AccuracyCallback
from catalyst.contrib.datasets import MNIST
import torch
from torch.utils.data import DataLoader

from compressors.distillation.runners import EndToEndDistilRunner
from compressors.models import MLP
from compressors.utils.data import TorchvisionDatasetWrapper as Wrp

teacher = MLP(num_layers=4)
student = MLP(num_layers=3)

datasets = {
    "train": Wrp(MNIST("./data", train=True, download=True)),
    "valid": Wrp(MNIST("./data", train=False)),
}

loaders = {
    dl_key: DataLoader(dataset, shuffle=dl_key == "train", batch_size=32)
    for dl_key, dataset in datasets.items()
}

optimizer = torch.optim.Adam(chain(teacher.parameters(), student.parameters()))

runner = EndToEndDistilRunner(hidden_state_loss="mse",
                              num_train_teacher_epochs=5)

runner.train(
    model={
def train_experiment(engine=None):
    with TemporaryDirectory() as logdir:
        utils.set_global_seed(RANDOM_STATE)
        # 1. train, valid and test loaders
        train_data = MNIST(DATA_ROOT, train=True)
        train_labels = train_data.targets.cpu().numpy().tolist()
        train_sampler = BatchBalanceClassSampler(train_labels,
                                                 num_classes=10,
                                                 num_samples=4)
        train_loader = DataLoader(train_data, batch_sampler=train_sampler)

        valid_dataset = MNIST(root=DATA_ROOT, train=False)
        valid_loader = DataLoader(dataset=valid_dataset, batch_size=32)

        test_dataset = MNIST(root=DATA_ROOT, train=False)
        test_loader = DataLoader(dataset=test_dataset, batch_size=32)

        # 2. model and optimizer
        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 16),
                              nn.LeakyReLU(inplace=True))
        optimizer = Adam(model.parameters(), lr=LR)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])

        # 3. criterion with triplets sampling
        sampler_inbatch = HardTripletsSampler(norm_required=False)
        criterion = TripletMarginLossWithSampler(
            margin=0.5, sampler_inbatch=sampler_inbatch)

        # 4. training with catalyst Runner
        class CustomRunner(dl.SupervisedRunner):
            def handle_batch(self, batch) -> None:
                images, targets = batch["features"].float(), batch["targets"].long()
                features = self.model(images)
                self.batch = {
                    "embeddings": features,
                    "targets": targets,
                }

        callbacks = [
            dl.ControlFlowCallbackWrapper(
                dl.CriterionCallback(input_key="embeddings",
                                     target_key="targets",
                                     metric_key="loss"),
                loaders="train",
            ),
            dl.SklearnModelCallback(
                feature_key="embeddings",
                target_key="targets",
                train_loader="train",
                valid_loaders=["valid", "infer"],
                model_fn=RandomForestClassifier,
                predict_method="predict_proba",
                predict_key="sklearn_predict",
                random_state=RANDOM_STATE,
                n_estimators=50,
            ),
            dl.ControlFlowCallbackWrapper(
                dl.AccuracyCallback(target_key="targets",
                                    input_key="sklearn_predict",
                                    topk=(1, 3)),
                loaders=["valid", "infer"],
            ),
        ]

        runner = CustomRunner(input_key="features", output_key="embeddings")
        runner.train(
            engine=engine,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            loaders={
                "train": train_loader,
                "valid": valid_loader,
                "infer": test_loader
            },
            verbose=False,
            valid_loader="valid",
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            num_epochs=TRAIN_EPOCH,
            logdir=logdir,
        )

        best_accuracy = max(
            epoch_metrics["infer"]["accuracy01"]
            for epoch_metrics in runner.experiment_metrics.values())

        assert best_accuracy > 0.8
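
The snippet above relies on several module-level constants that are not shown; hypothetical placeholders (values are illustrative, not taken from the original script):

RANDOM_STATE = 42
DATA_ROOT = "./data"
LR = 0.001
TRAIN_EPOCH = 6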