Example #1
def semixup_data_provider(model, alpha, n_classes, train_labeled_data, train_unlabeled_data, val_labeled_data,
                          transforms, parse_item, bs, num_workers, item_loaders=None, root="", augmentation=None,
                          data_rearrange=None):
    """
    Default setting of data provider for Semixup
    """
    if item_loaders is None:  # avoid sharing one mutable default dict across calls
        item_loaders = dict()
    # item_loaders["labeled_train"] = MixUpSampler(meta_data=train_labeled_data, name='l_mixup', alpha=alpha, model=model,
    #                                              transform=transforms['train'], parse_item_cb=parse_item, batch_size=bs,
    #                                              data_rearrange=data_rearrange,
    #                                              num_workers=num_workers, root=root, shuffle=True)
    item_loaders["labeled_train"] = ItemLoader(meta_data=train_labeled_data, name='l_norm', transform=transforms['train'],
                                              parse_item_cb=parse_item, batch_size=bs, num_workers=num_workers,
                                              root=root, shuffle=True)

    item_loaders["unlabeled_train"] = SemixupSampler(meta_data=train_unlabeled_data, name='u_mixup', alpha=alpha,
                                                     model=model, min_lambda=0.55,
                                                     transform=transforms['train'], parse_item_cb=parse_item, batch_size=bs,
                                                     data_rearrange=data_rearrange,
                                                     num_workers=num_workers, augmentation=augmentation, root=root,
                                                     shuffle=True)

    item_loaders["labeled_eval"] = ItemLoader(meta_data=val_labeled_data, name='l_norm', transform=transforms['eval'],
                                              parse_item_cb=parse_item, batch_size=bs, num_workers=num_workers,
                                              root=root, shuffle=False)

    return DataProvider(item_loaders)
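For reference, a minimal call sketch; the model, DataFrames, transforms and parse
callback below are hypothetical placeholders, and the 'train'/'eval' keys match how
the function indexes its transforms argument:

provider = semixup_data_provider(
    model=my_model, alpha=0.75, n_classes=5,
    train_labeled_data=df_train_l, train_unlabeled_data=df_train_u,
    val_labeled_data=df_val,
    transforms={'train': train_tf, 'eval': eval_tf},
    parse_item=parse_item_cb, bs=32, num_workers=4)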
Example #2
def create_data_provider(args, config, parser, metadata, mean, std):
    """
    Setup dataloader and augmentations
    :param args: General arguments
    :param config: Experiment parameters
    :param parser: Function for loading images
    :param metadata: Image paths and subject IDs
    :param mean: Dataset mean
    :param std: Dataset std
    :return: The compiled dataloader
    """
    # Compile ItemLoaders
    item_loaders = dict()
    for stage in ['train', 'val']:
        item_loaders[f'bfpn_{stage}'] = ItemLoader(
            meta_data=metadata[stage],
            transform=train_test_transforms(
                config,
                mean,
                std,
                crop_size=tuple(config['training']['crop_size']))[stage],
            parse_item_cb=parser,
            batch_size=config['training']['bs'],
            num_workers=args.num_threads,
            shuffle=True if stage == "train" else False)

    return DataProvider(item_loaders)
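A minimal call sketch; args, config, the parser callback and the metadata DataFrames
are hypothetical placeholders, and the normalization stats shown are the common
ImageNet values, used purely for illustration:

provider = create_data_provider(args, config, parser=parse_item_cb,
                                metadata={'train': df_train, 'val': df_val},
                                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))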
Example #3
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)

    train_ds, classes = get_cifar10(data_folder=data_dir, train=True)
    eval_ds, _ = get_cifar10(data_folder=data_dir, train=False)
    n_channels = 3

    criterion = torch.nn.CrossEntropyLoss()

    model = ResNet(in_channels=n_channels, n_features=64, drop_rate=0.3).to(device).half()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.wd,
                                nesterov=True)

    # Tensorboard visualization
    log_dir = cfg.log_dir
    comment = cfg.comment
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    item_loaders = dict()
    for stage, df in zip(['train', 'eval'], [train_ds, eval_ds]):
        item_loaders[f'loader_{stage}'] = ItemLoader(meta_data=df,
                                                     transform=my_transforms()[stage],
                                                     parse_item_cb=parse_item,
                                                     batch_size=cfg.bs, num_workers=cfg.num_workers,
                                                     shuffle=(stage == "train"))

    data_provider = DataProvider(item_loaders)

    train_cbs = (CosineAnnealingWarmRestartsWithWarmup(optimizer=optimizer, warmup_epochs=(0, 10, 20),
                                                       warmup_lrs=(0, 0.1, 0.01), T_0=5, T_mult=2, eta_min=0),
                 RunningAverageMeter(name="loss"),
                 AccuracyMeter(name="acc"))

    val_cbs = (RunningAverageMeter(name="loss"),
               AccuracyMeter(name="acc"),
               ScalarMeterLogger(writer=summary_writer),
               ModelSaver(metric_names='loss', save_dir=cfg.snapshots, conditions='min', model=model),
               ModelSaver(metric_names='acc', save_dir=cfg.snapshots, conditions='max', model=model))

    session = dict()
    session['mymodel'] = Session(data_provider=data_provider,
                                 train_loader_names=cfg.sampling.train.data_provider.mymodel.keys(),
                                 val_loader_names=cfg.sampling.eval.data_provider.mymodel.keys(),
                                 module=model, loss=criterion, optimizer=optimizer,
                                 train_callbacks=train_cbs,
                                 val_callbacks=val_cbs)

    strategy = Strategy(data_provider=data_provider,
                        data_sampling_config=cfg.sampling,
                        strategy_config=cfg.strategy,
                        sessions=session,
                        n_epochs=cfg.n_epochs,
                        device=device)

    strategy.run()
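The loader names handed to the Session come from cfg.sampling, so the config must
mirror the item_loaders keys built above (loader_train / loader_eval). A plausible
outline of that structure, with the per-loader settings elided as an assumption:

sampling = {
    'train': {'data_provider': {'mymodel': {'loader_train': {}}}},
    'eval':  {'data_provider': {'mymodel': {'loader_eval':  {}}}},
}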
Example #4
def ict_data_provider(model,
                      alpha,
                      n_classes,
                      train_labeled_data,
                      train_unlabeled_data,
                      val_labeled_data,
                      val_unlabeled_data,
                      transforms,
                      parse_item,
                      bs,
                      num_threads,
                      item_loaders=None,
                      root=""):
    """
    Default setting of data provider for ICT
    """
    if item_loaders is None:  # avoid sharing one mutable default dict across calls
        item_loaders = dict()
    item_loaders["labeled_train"] = ItemLoader(meta_data=train_labeled_data,
                                               name='l_norm',
                                               transform=transforms[1],
                                               parse_item_cb=parse_item,
                                               batch_size=bs,
                                               num_workers=num_threads,
                                               root=root,
                                               shuffle=True)

    item_loaders["unlabeled_train"] = MixUpSampler(
        meta_data=train_unlabeled_data,
        name='u_mixup',
        alpha=alpha,
        model=model,
        transform=transforms[0],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        root=root,
        shuffle=True)

    item_loaders["labeled_eval"] = ItemLoader(meta_data=val_labeled_data,
                                              name='l_norm',
                                              transform=transforms[1],
                                              parse_item_cb=parse_item,
                                              batch_size=bs,
                                              num_workers=num_threads,
                                              root=root,
                                              shuffle=False)

    item_loaders["unlabeled_eval"] = MixUpSampler(meta_data=val_unlabeled_data,
                                                  name='u_mixup',
                                                  alpha=alpha,
                                                  model=model,
                                                  transform=transforms[1],
                                                  parse_item_cb=parse_item,
                                                  batch_size=bs,
                                                  num_workers=num_threads,
                                                  root=root,
                                                  shuffle=False)

    return DataProvider(item_loaders)
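A minimal call sketch (all objects are hypothetical placeholders). Note the positional
transforms: judging from the indices used above, transforms[0] feeds the mixup stream
and transforms[1] is the plain transform used for the labeled and eval streams:

provider = ict_data_provider(model=my_model, alpha=0.75, n_classes=10,
                             train_labeled_data=df_l, train_unlabeled_data=df_u,
                             val_labeled_data=df_val_l, val_unlabeled_data=df_val_u,
                             transforms=(mixup_tf, plain_tf),
                             parse_item=parse_item_cb, bs=32, num_threads=4)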
Example #5
def create_data_provider(args, config, parser, metadata, mean, std):
    # Compile ItemLoaders
    item_loaders = dict()
    for stage in ['train', 'val']:
        item_loaders[f'bfpn_{stage}'] = ItemLoader(meta_data=metadata[stage],
                                                   transform=train_test_transforms(config, mean, std)[stage],
                                                   parse_item_cb=parser,
                                                   batch_size=args.bs, num_workers=args.num_threads,
                                                   shuffle=(stage == "train"))

    return DataProvider(item_loaders)
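This is a leaner variant of the create_data_provider in Example #2: the batch size
comes from the command-line arguments (args.bs) rather than the experiment config,
and train_test_transforms is called without an explicit crop_size.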
Example #6
def distributed_gan_data_provider(g_network, item_loaders, train_ds, classes,
                                  latent_size, transforms, parse_item, args):
    """
    Default setting of data provider for GAN

    """
    bs = args.batch_size
    num_threads = args.workers
    gpu = args.gpu
    distributed = args.distributed
    if torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(gpu))
    else:
        device = torch.device('cpu')
    if distributed:
        item_loaders['real'] = DistributedItemLoader(meta_data=train_ds,
                                                     transform=transforms,
                                                     parse_item_cb=parse_item,
                                                     args=args)
        item_loaders['fake'] = DistributedGANFakeSampler(
            g_network=g_network,
            batch_size=bs,
            latent_size=latent_size,
            gpu=gpu)

        item_loaders['noise'] = GaussianNoiseSampler(batch_size=bs,
                                                     latent_size=latent_size,
                                                     device=gpu,
                                                     n_classes=len(classes))

    else:
        item_loaders['real'] = ItemLoader(meta_data=train_ds,
                                          transform=transforms,
                                          parse_item_cb=parse_item,
                                          batch_size=bs,
                                          num_workers=num_threads,
                                          shuffle=True)

        item_loaders['fake'] = GANFakeSampler(g_network=g_network,
                                              batch_size=bs,
                                              latent_size=latent_size)

        item_loaders['noise'] = GaussianNoiseSampler(batch_size=bs,
                                                     latent_size=latent_size,
                                                     device=device,
                                                     n_classes=len(classes))

    return DataProvider(item_loaders)
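In the distributed branch, one such provider is built per GPU process. A hypothetical
launch sketch; run_worker and every name it closes over (g_network, train_ds, classes,
latent_size, transforms, parse_item, args) are placeholders:

import torch
import torch.multiprocessing as mp

def run_worker(gpu, args):
    args.gpu = gpu
    provider = distributed_gan_data_provider(g_network, dict(), train_ds, classes,
                                             latent_size, transforms, parse_item, args)
    # ... per-rank training loop ...

# mp.spawn(run_worker, args=(args,), nprocs=torch.cuda.device_count())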
Example #7
def mixmatch_ema_data_provider(model,
                               augmentation,
                               labeled_meta_data,
                               unlabeled_meta_data,
                               val_labeled_data,
                               n_augmentations,
                               parse_item,
                               bs,
                               transforms,
                               root="",
                               num_threads=4):
    itemloader_dict = {}
    itemloader_dict['all_train'] = MixMatchSampler(
        model=model,
        name="train_mixmatch",
        augmentation=augmentation,
        labeled_meta_data=labeled_meta_data,
        unlabeled_meta_data=unlabeled_meta_data,
        n_augmentations=n_augmentations,
        data_key='data',
        target_key='target',
        parse_item_cb=parse_item,
        batch_size=bs,
        transform=transforms[0],
        num_workers=num_threads,
        shuffle=True)

    itemloader_dict['labeled_eval_st'] = ItemLoader(root=root,
                                                    meta_data=val_labeled_data,
                                                    name='l_eval',
                                                    transform=transforms[1],
                                                    parse_item_cb=parse_item,
                                                    batch_size=bs,
                                                    num_workers=num_threads,
                                                    shuffle=False)

    itemloader_dict['labeled_eval_te'] = ItemLoader(root=root,
                                                    meta_data=val_labeled_data,
                                                    name='l_eval',
                                                    transform=transforms[1],
                                                    parse_item_cb=parse_item,
                                                    batch_size=bs,
                                                    num_workers=num_threads,
                                                    shuffle=False)

    return DataProvider(itemloader_dict)
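The two eval loaders are deliberately identical: in an EMA student/teacher setup the
student ('_st') and teacher ('_te') each evaluate on the same labeled validation
stream, so their metrics can be tracked side by side.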
Example #8
def init_data_provider(args, df_train, df_val, item_loaders, test_ds):
    """
    function to initialize data provider for the autoencoder
    Parameters
    ----------
    args: Namespace
        arguments for the whole network parsed using argparse
    df_train: DataFrame
        training data as pandas DataFrame
    df_val: DataFrame
        validation data as pandas DataFrame
    item_loaders: dict
        empty dictionary to be populated by data samplers
    test_ds: DataFrame
        test data for visualization as pandas DataFrame


    Returns
    -------
    DataProvider
        DataProvider object constructed from all the data samplers
    """
    for stage, df in zip(['train', 'eval'], [df_train, df_val]):
        item_loaders[f'mnist_{stage}'] = ItemLoader(
            meta_data=df,
            transform=init_mnist_transforms()[0],
            parse_item_cb=parse_item_ae,
            batch_size=args.bs,
            num_workers=args.num_threads,
            shuffle=(stage == 'train'))

    item_loaders['mnist_viz'] = ItemLoader(
        meta_data=test_ds,
        transform=init_mnist_transforms()[0],
        parse_item_cb=parse_item_ae,
        batch_size=args.bs,
        num_workers=args.num_threads,
        shuffle=False)  # the original expression reused the loop variable 'stage', which leaks out of the loop; the viz loader is not shuffled
    return DataProvider(item_loaders)
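A minimal call sketch; the DataFrames are hypothetical placeholders and only the two
fields the function actually reads from args are populated:

import argparse

args = argparse.Namespace(bs=32, num_threads=4)
provider = init_data_provider(args, df_train, df_val, dict(), test_ds)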
Example #9
def gan_data_provider(g_network, item_loaders, train_ds, classes, latent_size,
                      transforms, parse_item, bs, num_threads, device):
    """
    Default setting of data provider for GAN

    """
    item_loaders['real'] = ItemLoader(meta_data=train_ds,
                                      transform=transforms,
                                      parse_item_cb=parse_item,
                                      batch_size=bs,
                                      num_workers=num_threads,
                                      shuffle=True)

    item_loaders['fake'] = GANFakeSampler(g_network=g_network,
                                          batch_size=bs,
                                          latent_size=latent_size)

    item_loaders['noise'] = GaussianNoiseSampler(batch_size=bs,
                                                 latent_size=latent_size,
                                                 device=device,
                                                 n_classes=len(classes))

    return DataProvider(item_loaders)
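A minimal call sketch with hypothetical placeholders for the generator, training
DataFrame, transform and parse callback:

import torch

provider = gan_data_provider(g_network=my_generator, item_loaders=dict(),
                             train_ds=df_train, classes=list(range(10)),
                             latent_size=128, transforms=train_tf,
                             parse_item=parse_item_cb, bs=64, num_threads=4,
                             device=torch.device('cuda:0'))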
Example #10
        num_workers=args.num_threads,
        shuffle=False)

    item_loaders['fake_unlabeled_gen'] = SSGANFakeSampler(
        g_network=g_network,
        batch_size=args.bs,
        latent_size=args.latent_size,
        n_classes=args.n_classes)

    item_loaders['fake_unlabeled_latent'] = SSGANFakeSampler(
        g_network=g_network,
        batch_size=args.bs,
        latent_size=args.latent_size,
        n_classes=args.n_classes)

    data_provider = DataProvider(item_loaders)

    # Callbacks
    g_callbacks_train = ClipGradCallback(g_network,
                                         mode="norm",
                                         max_norm=0.1,
                                         norm_type=2)

    d_callbacks_train = (SSValidityMeter(threshold=0.5,
                                         sigmoid=False,
                                         prefix="train/D",
                                         name="ss_valid"),
                         SSAccuracyMeter(prefix="train/D", name="ss_acc"))

    d_callbacks_eval = (SSValidityMeter(threshold=0.5,
                                        sigmoid=False,
                                        prefix="eval/D",
                                        name="ss_valid"),
                        SSAccuracyMeter(prefix="eval/D", name="ss_acc"))
    # NOTE: the snippet was truncated at this point in the source; the eval callbacks
    # above are an assumed completion mirroring d_callbacks_train with an "eval/D" prefix.
Example #11
def mt_data_provider(st_model,
                     te_model,
                     train_labeled_data,
                     train_unlabeled_data,
                     val_labeled_data,
                     val_unlabeled_data,
                     transforms,
                     parse_item,
                     bs,
                     num_threads,
                     item_loaders=None,
                     n_augmentations=1,
                     output_type='logits',
                     root=""):
    """
    Default setting of data provider for Mean-Teacher
    """
    if item_loaders is None:  # avoid sharing one mutable default dict across calls
        item_loaders = dict()

    # Train
    item_loaders["labeled_train_st"] = AugmentedGroupSampler(
        root=root,
        name='l_st',
        meta_data=train_labeled_data,
        model=st_model,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=True)

    item_loaders["unlabeled_train_st"] = AugmentedGroupSampler(
        root=root,
        name='u_st',
        model=st_model,
        meta_data=train_unlabeled_data,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=True)

    item_loaders["labeled_train_te"] = AugmentedGroupSampler(
        root=root,
        name='l_te',
        meta_data=train_labeled_data,
        model=te_model,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        detach=True,
        shuffle=True)

    item_loaders["unlabeled_train_te"] = AugmentedGroupSampler(
        root=root,
        name='u_te',
        model=te_model,
        meta_data=train_unlabeled_data,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        detach=True,
        shuffle=True)

    # Eval

    item_loaders["labeled_eval_st"] = AugmentedGroupSampler(
        root=root,
        name='l_st',
        meta_data=val_labeled_data,
        model=st_model,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=False)

    item_loaders["unlabeled_eval_st"] = AugmentedGroupSampler(
        root=root,
        name='u_st',
        model=st_model,
        meta_data=val_unlabeled_data,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[1],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=False)

    item_loaders["labeled_eval_te"] = ItemLoader(root=root,
                                                 meta_data=val_labeled_data,
                                                 name='l_te_eval',
                                                 transform=transforms[1],
                                                 parse_item_cb=parse_item,
                                                 batch_size=bs,
                                                 num_workers=num_threads,
                                                 shuffle=False)

    return DataProvider(item_loaders)
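Note the detach=True on the two teacher ('_te') train loaders: in Mean-Teacher the
teacher is updated as an exponential moving average of the student rather than by
backpropagation, so its forward passes must not carry gradients.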
Example #12
def pimodel_data_provider(model,
                          train_labeled_data,
                          train_unlabeled_data,
                          val_labeled_data,
                          val_unlabeled_data,
                          transforms,
                          parse_item,
                          bs,
                          num_threads,
                          item_loaders=None,
                          root="",
                          n_augmentations=1,
                          output_type='logits'):
    """
    Default setting of data provider for Pi-Model
    """
    if item_loaders is None:  # avoid sharing one mutable default dict across calls
        item_loaders = dict()
    item_loaders["labeled_train"] = ItemLoader(root=root,
                                               meta_data=train_labeled_data,
                                               name='l',
                                               transform=transforms[0],
                                               parse_item_cb=parse_item,
                                               batch_size=bs,
                                               num_workers=num_threads,
                                               shuffle=True)

    # item_loaders["labeled_train"] = AugmentedGroupSampler(root=root, model=model, name='l', output_type=output_type,
    #                                                       meta_data=train_labeled_data,
    #                                                       n_augmentations=n_augmentations,
    #                                                       augmentation=transforms[2],
    #                                                       transform=transforms[1],
    #                                                       parse_item_cb=parse_item,
    #                                                       batch_size=bs, num_workers=num_workers,
    #                                                       shuffle=True)

    item_loaders["unlabeled_train"] = AugmentedGroupSampler(
        root=root,
        model=model,
        name='u',
        output_type=output_type,
        meta_data=train_unlabeled_data,
        n_augmentations=n_augmentations,
        augmentation=transforms[2],
        transform=transforms[0],
        parse_item_cb=parse_item,
        batch_size=bs,
        num_workers=num_threads,
        shuffle=True)

    item_loaders["labeled_eval"] = ItemLoader(root=root,
                                              meta_data=val_labeled_data,
                                              name='l',
                                              transform=transforms[1],
                                              parse_item_cb=parse_item,
                                              batch_size=bs,
                                              num_workers=num_threads,
                                              shuffle=False)

    # item_loaders["labeled_eval"] = AugmentedGroupSampler(root=root, model=model, name='l', output_type=output_type,
    #                                                      meta_data=val_labeled_data,
    #                                                      n_augmentations=n_augmentations,
    #                                                      augmentation=transforms[2],
    #                                                      transform=transforms[1],
    #                                                      parse_item_cb=parse_item,
    #                                                      batch_size=bs, num_workers=num_workers,
    #                                                      shuffle=False)

    # item_loaders["unlabeled_eval"] = AugmentedGroupSampler(root=root, model=model, name='u', output_type=output_type,
    #                                                        meta_data=val_unlabeled_data,
    #                                                        n_augmentations=n_augmentations,
    #                                                        augmentation=transforms[2],
    #                                                        transform=transforms[1],
    #                                                        parse_item_cb=parse_item,
    #                                                        batch_size=bs, num_workers=num_workers,
    #                                                        shuffle=False)

    return DataProvider(item_loaders)
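Judging from the indices used above, transforms is expected to hold at least three
entries: transforms[0] for training, transforms[1] for evaluation, and transforms[2]
as the stochastic augmentation applied n_augmentations times to produce the perturbed
views that the Pi-Model consistency term compares.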
Example #13
def worker_process(gpu, ngpus, sampling_config, strategy_config, args):
    args.gpu = gpu  # bind this worker to its GPU id; args.gpu is read again below and by convert_according_to_args
    if args.distributed:
        lr_m = float(args.batch_size * args.world_size) / 256.
    else:
        lr_m = 1.0
    criterion = torch.nn.CrossEntropyLoss().to(gpu)
    train_ds, classes = get_mnist(data_folder=args.save_data, train=True)
    test_ds, _ = get_mnist(data_folder=args.save_data, train=False)
    model = SimpleConvNet(bw=args.bw,
                          drop=args.dropout,
                          n_cls=len(classes),
                          n_channels=args.n_channels).to(gpu)
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=args.lr * lr_m,
                                 weight_decay=args.wd)

    args, model, optimizer = convert_according_to_args(args=args,
                                                       gpu=gpu,
                                                       ngpus=ngpus,
                                                       network=model,
                                                       optim=optimizer)

    item_loaders = dict()
    for stage, df in zip(['train', 'eval'], [train_ds, test_ds]):
        if args.distributed:
            item_loaders[f'mnist_{stage}'] = DistributedItemLoader(
                meta_data=df,
                transform=init_mnist_cifar_transforms(1, stage),
                parse_item_cb=parse_item_mnist,
                args=args)
        else:
            item_loaders[f'mnist_{stage}'] = ItemLoader(
                meta_data=df,
                transform=init_mnist_cifar_transforms(1, stage),
                parse_item_cb=parse_item_mnist,
                batch_size=args.batch_size,
                num_workers=args.workers,
                shuffle=(stage == "train"))
    data_provider = DataProvider(item_loaders)
    if args.gpu == 0:
        log_dir = args.log_dir
        comment = args.comment
        summary_writer = SummaryWriter(log_dir=log_dir,
                                       comment='_' + comment + 'gpu_' +
                                       str(args.gpu))
        train_cbs = (RunningAverageMeter(prefix="train", name="loss"),
                     AccuracyMeter(prefix="train", name="acc"))

        val_cbs = (RunningAverageMeter(prefix="eval", name="loss"),
                   AccuracyMeter(prefix="eval", name="acc"),
                   ScalarMeterLogger(writer=summary_writer),
                   ModelSaver(metric_names='eval/loss',
                              save_dir=args.snapshots,
                              conditions='min',
                              model=model))
    else:
        train_cbs = ()
        val_cbs = ()

    strategy = Strategy(data_provider=data_provider,
                        train_loader_names=tuple(
                            sampling_config['train']['data_provider'].keys()),
                        val_loader_names=tuple(
                            sampling_config['eval']['data_provider'].keys()),
                        data_sampling_config=sampling_config,
                        loss=criterion,
                        model=model,
                        n_epochs=args.n_epochs,
                        optimizer=optimizer,
                        train_callbacks=train_cbs,
                        val_callbacks=val_cbs,
                        device=torch.device('cuda:{}'.format(args.gpu)),
                        distributed=args.distributed,
                        use_apex=args.use_apex)

    strategy.run()
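The lr_m factor implements the linear scaling rule (total batch size / 256): with,
say, batch_size=64 and world_size=8, lr_m = 64 * 8 / 256 = 2.0, so the base learning
rate is doubled. Logging and snapshot callbacks are attached only on rank 0
(args.gpu == 0) so that the other workers do not write duplicate TensorBoard events
or model files.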
Example #14
def main(cfg):
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    data_dir = os.path.join(os.environ['PWD'], cfg.data_dir)

    train_ds, classes = get_mnist(data_folder=data_dir, train=True)
    n_classes = len(classes)
    n_channels = 1

    criterion = torch.nn.CrossEntropyLoss()

    # Tensorboard visualization
    log_dir = cfg.log_dir
    comment = cfg.comment
    summary_writer = SummaryWriter(log_dir=log_dir, comment=comment)

    splitter = FoldSplit(train_ds, n_folds=5, target_col="target")

    for fold_id, (df_train, df_val) in enumerate(splitter):
        item_loaders = dict()

        for stage, df in zip(['train', 'eval'], [df_train, df_val]):
            item_loaders[f'loader_{stage}'] = ItemLoader(
                meta_data=df,
                transform=my_transforms()[stage],
                parse_item_cb=parse_item,
                batch_size=cfg.bs,
                num_workers=cfg.num_threads,
                shuffle=(stage == "train"))

        model = SimpleConvNet(bw=cfg.bw,
                              drop_rate=cfg.dropout,
                              n_classes=n_classes).to(device)
        optimizer = torch.optim.Adam(params=model.parameters(),
                                     lr=cfg.lr,
                                     weight_decay=cfg.wd)
        data_provider = DataProvider(item_loaders)

        train_cbs = (RunningAverageMeter(name="loss"),
                     AccuracyMeter(name="acc"))

        val_cbs = (RunningAverageMeter(name="loss"), AccuracyMeter(name="acc"),
                   ScalarMeterLogger(writer=summary_writer),
                   ModelSaver(metric_names='loss',
                              save_dir=cfg.snapshots,
                              conditions='min',
                              model=model),
                   ModelSaver(metric_names='acc',
                              save_dir=cfg.snapshots,
                              conditions='max',
                              model=model))

        session = dict()
        session['mymodel'] = Session(
            data_provider=data_provider,
            train_loader_names=cfg.sampling.train.data_provider.mymodel.keys(),
            val_loader_names=cfg.sampling.eval.data_provider.mymodel.keys(),
            module=model,
            loss=criterion,
            optimizer=optimizer,
            train_callbacks=train_cbs,
            val_callbacks=val_cbs)

        strategy = Strategy(data_provider=data_provider,
                            data_sampling_config=cfg.sampling,
                            strategy_config=cfg.strategy,
                            sessions=session,
                            n_epochs=cfg.n_epochs,
                            device=device)

        strategy.run()
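The model, optimizer, DataProvider, Session and Strategy are all re-created inside the
fold loop, so every fold trains from scratch; only the SummaryWriter is shared across
folds. Since each fold also saves into the same cfg.snapshots directory, whether later
folds overwrite earlier checkpoints depends on how ModelSaver names its files.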