Example #1
# Third-party imports used below; project-local helpers such as
# create_experiment_config_argparser, AcquisitionFunction, AcquisitionMethod,
# DatasetEnum, get_experiment_data, RandomFixedLengthSampler, ContextStopwatch,
# get_base_indices, and get_targets are assumed to come from the surrounding repo.
import argparse
import functools
import itertools
import os
import sys

import torch
import wandb

import laaos
import blackhc.laaos


def main():
    parser = argparse.ArgumentParser(
        description="BatchBALD",
        formatter_class=functools.partial(
            argparse.ArgumentDefaultsHelpFormatter, width=120))
    parser.add_argument("--experiment_id",
                        type=str,
                        default='default_group',
                        help="experiment id")
    parser.add_argument("--run_name",
                        type=str,
                        default='default_run',
                        help='name of this run')

    parser.add_argument(
        "--experiments_laaos",
        type=str,
        default=None,
        help="Laaos file that contains all experiment task configs")
    parser.add_argument("--experiment_description",
                        type=str,
                        default="Trying stuff..",
                        help="Description of the experiment")
    parser = create_experiment_config_argparser(parser)
    args = parser.parse_args()

    if args.experiments_laaos is not None:
        config = laaos.safe_load(args.experiments_laaos,
                                 expose_symbols=(AcquisitionFunction,
                                                 AcquisitionMethod,
                                                 DatasetEnum))
        # Merge the experiment config with the command-line args: config
        # values act as defaults, and explicitly passed options take priority.
        args = parser.parse_args(namespace=argparse.Namespace(
            **config[args.experiment_id]))

    wandb_run = wandb.init(project='batchbald-reproduce',
                           entity='skoltech-nlp',
                           group=args.experiment_id,
                           name=args.run_name,
                           notes=args.experiment_description,
                           config=args.__dict__,
                           tags=[
                               str(args.type),
                               str(args.dataset),
                               f"MC={args.num_inference_samples}",
                               f"AL-STEP-SIZE={args.available_sample_k}"
                           ])

    # DONT TRUNCATE LOG FILES EVER AGAIN!!! (OFC THIS HAD TO HAPPEN AND BE PAINFUL)
    reduced_dataset = args.quickquick
    if args.experiment_id:
        store_name = args.experiment_id
        if reduced_dataset:
            store_name = "quickquick_" + store_name
    else:
        store_name = "results"

    # Make sure we have a directory to store the results in, and we don't crash!
    os.makedirs("./laaos", exist_ok=True)
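    # laaos file stores are append-only Python logs; everything written to
    # `store` below can be replayed later (e.g. with laaos.safe_load).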
    store = laaos.create_file_store(
        store_name,
        suffix="",
        truncate=False,
        type_handlers=(blackhc.laaos.StrEnumHandler(),
                       blackhc.laaos.ToReprHandler()),
    )
    store["args"] = args.__dict__
    store["cmdline"] = sys.argv[:]

    print("|".join(sys.argv))
    print(args.__dict__)

    acquisition_method: AcquisitionMethod = args.acquisition_method

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    print(f"Using {device} for computations")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    dataset: DatasetEnum = args.dataset
    samples_per_class = args.initial_samples_per_class
    validation_set_size = args.validation_set_size
    balanced_test_set = args.balanced_test_set
    balanced_validation_set = args.balanced_validation_set

    experiment_data = get_experiment_data(
        data_source=dataset.get_data_source(),
        num_classes=dataset.num_classes,
        initial_samples=args.initial_samples,
        reduced_dataset=reduced_dataset,
        samples_per_class=samples_per_class,
        validation_set_size=validation_set_size,
        balanced_test_set=balanced_test_set,
        balanced_validation_set=balanced_validation_set,
    )

    test_loader = torch.utils.data.DataLoader(experiment_data.test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

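    # RandomFixedLengthSampler (a repo helper) keeps every epoch at a fixed
    # length of args.epoch_samples draws even while the labelled training set
    # is still small (cf. Example #4, which switches to a plain RandomSampler
    # once the set is large enough).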
    train_loader = torch.utils.data.DataLoader(
        experiment_data.train_dataset,
        sampler=RandomFixedLengthSampler(experiment_data.train_dataset,
                                         args.epoch_samples),
        batch_size=args.batch_size,
        **kwargs,
    )

    available_loader = torch.utils.data.DataLoader(
        experiment_data.available_dataset,
        batch_size=args.scoring_batch_size,
        shuffle=False,
        **kwargs)

    validation_loader = torch.utils.data.DataLoader(
        experiment_data.validation_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        **kwargs)

    store["iterations"] = []
    # store wraps the empty list in a storable list, so we need to fetch it separately.
    iterations = store["iterations"]

    store["initial_samples"] = experiment_data.initial_samples

    acquisition_function: AcquisitionFunction = args.type
    max_epochs = args.epochs

    for iteration in itertools.count(1):

        def desc(name):
            return lambda engine: "%s: %s (%s samples)" % (
                name, iteration, len(experiment_data.train_dataset))
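        # desc builds the progress-description callback passed to train_model;
        # it closes over `iteration`, so each AL step reports its own
        # iteration number and current train-set size.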

        with ContextStopwatch() as train_model_stopwatch:
            early_stopping_patience = args.early_stopping_patience
            num_inference_samples = args.num_inference_samples
            log_interval = args.log_interval

            model, num_epochs, test_metrics = dataset.train_model(
                train_loader,
                test_loader,
                validation_loader,
                num_inference_samples,
                max_epochs,
                early_stopping_patience,
                desc,
                log_interval,
                device,
            )

        with ContextStopwatch() as batch_acquisition_stopwatch:
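            # k = number of MC inference samples, b = number of points
            # acquired per AL step (matching the MC= / AL-STEP-SIZE= wandb
            # tags above).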
            batch = acquisition_method.acquire_batch(
                bayesian_model=model,
                acquisition_function=acquisition_function,
                available_loader=available_loader,
                num_classes=dataset.num_classes,
                k=args.num_inference_samples,
                b=args.available_sample_k,
                min_candidates_per_acquired_item=args.min_candidates_per_acquired_item,
                min_remaining_percentage=args.min_remaining_percentage,
                initial_percentage=args.initial_percentage,
                reduce_percentage=args.reduce_percentage,
                device=device,
            )

        original_batch_indices = get_base_indices(
            experiment_data.available_dataset, batch.indices)
        print(f"Acquiring indices {original_batch_indices}")
        targets = get_targets(experiment_data.available_dataset)
        acquired_targets = [int(targets[index]) for index in batch.indices]
        print(f"Acquiring targets {acquired_targets}")

        iterations.append(
            dict(
                num_epochs=num_epochs,
                test_metrics=test_metrics,
                chosen_targets=acquired_targets,
                chosen_samples=original_batch_indices,
                chosen_samples_score=batch.scores,
                chosen_samples_orignal_score=batch.orignal_scores,  # sic, matches the upstream attribute spelling
                train_model_elapsed_time=train_model_stopwatch.elapsed_time,
                batch_acquisition_elapsed_time=batch_acquisition_stopwatch.elapsed_time,
            ))
        wandb.log({
            'accuracy': test_metrics['accuracy'],
            'nll': test_metrics['nll'],
            'acquisition_elapsed_time': batch_acquisition_stopwatch.elapsed_time,
            'training_elapsed_time': train_model_stopwatch.elapsed_time,
        })

        experiment_data.active_learning_data.acquire(batch.indices)

        num_acquired_samples = len(
            experiment_data.active_learning_data.active_dataset) - len(
                experiment_data.initial_samples)
        if num_acquired_samples >= args.target_num_acquired_samples:
            print(
                f"{num_acquired_samples} acquired samples >= {args.target_num_acquired_samples}"
            )
            break
        #if test_metrics["accuracy"] >= args.target_accuracy:
        #    print(f'accuracy {test_metrics["accuracy"]} >= {args.target_accuracy}')
        #    break

    print("DONE")
Example #2
# Assumes the imports from Example #1, plus repo-local dataset_enum,
# recover_args, get_samples_from_laaos_store, RecoveredModel, and Loaders.
def recover_model(laaos_store, target_iteration=None):
    args = recover_args(laaos_store)
    sample_indices = get_samples_from_laaos_store(laaos_store,
                                                  target_iteration)

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    print(f"Using {device} for computations")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    dataset: dataset_enum.DatasetEnum = args.dataset
    experiment_data = dataset_enum.get_experiment_data(
        data_source=dataset.get_data_source(),
        num_classes=dataset.num_classes,
        initial_samples=sample_indices,
        reduced_dataset=False,
        samples_per_class=0,
        validation_set_size=args.validation_set_size,
        balanced_test_set=args.balanced_test_set,
        balanced_validation_set=args.balanced_validation_set,
    )

    test_loader = torch.utils.data.DataLoader(experiment_data.test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    train_loader = torch.utils.data.DataLoader(
        experiment_data.train_dataset,
        sampler=RandomFixedLengthSampler(experiment_data.train_dataset,
                                         args.epoch_samples),
        batch_size=args.batch_size,
        **kwargs,
    )

    validation_loader = torch.utils.data.DataLoader(
        experiment_data.validation_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        **kwargs)

    log_interval = args.log_interval
    num_inference_samples = args.num_inference_samples
    early_stopping_patience = args.early_stopping_patience
    max_epochs = args.epochs

    def desc(name):
        return lambda engine: "%s" % name

    model, num_epochs, test_metrics = dataset.train_model(
        train_loader,
        test_loader,
        validation_loader,
        num_inference_samples,
        max_epochs,
        early_stopping_patience,
        desc,
        log_interval,
        device,
    )

    return RecoveredModel(
        args,
        model,
        num_epochs,
        test_metrics,
        experiment_data.active_learning_data,
        experiment_data.validation_dataset,
        experiment_data.test_dataset,
        Loaders(train_loader, test_loader, validation_loader),
    )
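

# Hypothetical usage sketch (the store path is an assumption): load a results
# log written by the experiment scripts, then retrain from its recorded samples.
#
#     store = laaos.safe_load("laaos/default_group")
#     recovered = recover_model(store, target_iteration=None)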
Example #3
# Assumes the imports from Example #1, plus numpy as np, random, and a
# gpu_init helper for GPU selection provided by the environment.
def main():
    parser = argparse.ArgumentParser(
        description="BatchBALD",
        formatter_class=functools.partial(
            argparse.ArgumentDefaultsHelpFormatter, width=120))
    parser.add_argument("--experiment_task_id",
                        type=str,
                        default=None,
                        help="experiment id")
    parser.add_argument(
        "--experiments_laaos",
        type=str,
        default=None,
        help="Laaos file that contains all experiment task configs")
    parser.add_argument("--experiment_description",
                        type=str,
                        default="Trying stuff..",
                        help="Description of the experiment")
    parser = create_experiment_config_argparser(parser)
    args = parser.parse_args()

    if args.gpu == -1:
        gpu_id = gpu_init(best_gpu_metric="mem")
    else:
        gpu_id = gpu_init(gpu_id=args.gpu)
    print("Running on GPU " + str(gpu_id))

    if args.experiments_laaos is not None:
        config = laaos.safe_load(args.experiments_laaos,
                                 expose_symbols=(AcquisitionFunction,
                                                 AcquisitionMethod,
                                                 DatasetEnum))
        # Merge the experiment config with the command-line args: config
        # values act as defaults, and explicitly passed options take priority.
        args = parser.parse_args(namespace=argparse.Namespace(
            **config[args.experiment_task_id]))

    # DONT TRUNCATE LOG FILES EVER AGAIN!!! (OFC THIS HAD TO HAPPEN AND BE PAINFUL)
    reduced_dataset = args.quickquick
    if args.experiment_task_id:
        store_name = args.experiment_task_id
        if reduced_dataset:
            store_name = "quickquick_" + store_name
    else:
        store_name = "results"

    # Make sure we have a directory to store the results in, and we don't crash!
    os.makedirs("./laaos", exist_ok=True)
    store = laaos.create_file_store(
        store_name,
        suffix="",
        truncate=False,
        type_handlers=(blackhc.laaos.StrEnumHandler(),
                       blackhc.laaos.ToReprHandler()),
    )
    store["args"] = args.__dict__
    store["cmdline"] = sys.argv[:]

    print("|".join(sys.argv))
    print(args.__dict__)

    acquisition_method: AcquisitionMethod = args.acquisition_method

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    if args.fix_numpy_python_seed:
        np.random.seed(args.seed)
        random.seed(args.seed)
    if args.cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
        #torch.backends.cudnn.benchmark = False
    device = torch.device("cuda" if use_cuda else "cpu")

    print(f"Using {device} for computations")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    dataset: DatasetEnum = args.dataset
    samples_per_class = args.initial_samples_per_class
    validation_set_size = args.validation_set_size
    balanced_test_set = args.balanced_test_set
    balanced_validation_set = args.balanced_validation_set

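    # Optionally re-seed args.initial_samples from a previous run's laaos log:
    # the "store['initial_samples']=[...]" line and per-iteration acquisition
    # entries are parsed back into one flat list, stopping once
    # max_num_batch_init_samples_to_read entries have been read.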
    if args.file_with_initial_samples != "":
        if args.initial_samples is None:
            args.initial_samples = []
        num_read = 0
        with open(args.file_with_initial_samples) as f:
            for line in f:
                cur_samples = []
                if line.startswith("store['initial_samples']"):
                    cur_samples = [
                        int(k)
                        for k in line.strip().split('=')[1][1:-1].split(',')
                    ]
                    num_read += 1
                elif "chosen_targets" in line:
                    line = line.strip().split("'chosen_targets': [")[1]
                    line = line.split("]")[0]
                    cur_samples = [int(k) for k in line.split(',')]
                    num_read += 1
                args.initial_samples += cur_samples
                if num_read >= args.max_num_batch_init_samples_to_read:
                    break

    experiment_data = get_experiment_data(
        data_source=dataset.get_data_source(),
        num_classes=dataset.num_classes,
        initial_samples=args.initial_samples,
        reduced_dataset=reduced_dataset,
        samples_per_class=samples_per_class,
        validation_set_size=validation_set_size,
        balanced_test_set=balanced_test_set,
        balanced_validation_set=balanced_validation_set,
    )

    test_loader = torch.utils.data.DataLoader(experiment_data.test_dataset,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    train_loader = torch.utils.data.DataLoader(
        experiment_data.train_dataset,
        sampler=RandomFixedLengthSampler(experiment_data.train_dataset,
                                         args.epoch_samples),
        batch_size=args.batch_size,
        **kwargs,
    )

    available_loader = torch.utils.data.DataLoader(
        experiment_data.available_dataset,
        batch_size=args.scoring_batch_size,
        shuffle=False,
        **kwargs)

    validation_loader = torch.utils.data.DataLoader(
        experiment_data.validation_dataset,
        batch_size=args.test_batch_size,
        shuffle=False,
        **kwargs)

    store["iterations"] = []
    # store wraps the empty list in a storable list, so we need to fetch it separately.
    iterations = store["iterations"]

    store["initial_samples"] = experiment_data.initial_samples

    acquisition_function: AcquisitionFunction = args.type
    max_epochs = args.epochs

    for iteration in itertools.count(1):

        def desc(name):
            return lambda engine: "%s: %s (%s samples)" % (
                name, iteration, len(experiment_data.train_dataset))

        with ContextStopwatch() as train_model_stopwatch:
            early_stopping_patience = args.early_stopping_patience
            num_inference_samples = args.num_inference_samples
            log_interval = args.log_interval

            model, num_epochs, test_metrics = dataset.train_model(
                train_loader,
                test_loader,
                validation_loader,
                num_inference_samples,
                max_epochs,
                early_stopping_patience,
                desc,
                log_interval,
                device,
            )
        target_size = max(
            args.min_candidates_per_acquired_item * args.available_sample_k,
            len(available_loader.dataset) * args.min_remaining_percentage // 100,
        )
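        # Score the (reduced) remaining pool with predictive entropy; its mean
        # is stored below as active_entropy for this AL iteration.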
        result = reduced_eval_consistent_bayesian_model(
            bayesian_model=model,
            acquisition_function=AcquisitionFunction.predictive_entropy,
            num_classes=dataset.num_classes,
            k=args.num_inference_samples,
            initial_percentage=args.initial_percentage,
            reduce_percentage=args.reduce_percentage,
            target_size=target_size,
            available_loader=available_loader,
            device=device,
        )
        print("entropy score shape:", result.scores_B.numpy().shape)
        entropy_score = result.scores_B.numpy().mean()
        to_store = {}
        with ContextStopwatch() as batch_acquisition_stopwatch:
            ret = acquisition_method.acquire_batch(
                bayesian_model=model,
                acquisition_function=acquisition_function,
                available_loader=available_loader,
                num_classes=dataset.num_classes,
                k=args.num_inference_samples,
                b=args.available_sample_k,
                min_candidates_per_acquired_item=args.min_candidates_per_acquired_item,
                min_remaining_percentage=args.min_remaining_percentage,
                initial_percentage=args.initial_percentage,
                reduce_percentage=args.reduce_percentage,
                max_batch_compute_size=args.max_batch_compute_size,
                hsic_compute_batch_size=args.hsic_compute_batch_size,
                hsic_kernel_name=args.hsic_kernel_name,
                fass_entropy_bag_size_factor=args.fass_entropy_bag_size_factor,
                hsic_resample=args.hsic_resample,
                ical_max_greedy_iterations=args.ical_max_greedy_iterations,
                device=device,
                store=to_store,
                random_ical_minibatch=args.random_ical_minibatch,
                num_to_condense=args.num_to_condense,
                num_inference_for_marginal_stat=args.num_inference_for_marginal_stat,
                use_orig_condense=args.use_orig_condense,
            )
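            # acquire_batch may return either the batch alone or a
            # (batch, time_taken) tuple; the exact `type(ret) is tuple` check
            # (rather than isinstance) leaves batch objects that subclass
            # tuple intact.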
            if type(ret) is tuple:
                batch, time_taken = ret
            else:
                batch = ret

        probs_B_K_C = result.logits_B_K_C.exp()
        B, K, C = probs_B_K_C.shape  # (pool size, num MC samples, num classes)
        probs_B_C = probs_B_K_C.mean(dim=1)
        prior_entropy = -(probs_B_C * probs_B_C.log()).sum(dim=-1)  # (pool size,)
        mi = 0.
        post_entropy = 0.
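        # Monte-Carlo estimate of batch informativeness: for each acquired
        # point (at most the first 100), MI is approximated as the prior
        # predictive entropy minus the entropy after conditioning on labels
        # sampled for the batch.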
        sample_B_K_C = generate_sample(probs_B_K_C)
        for i, idx in enumerate(batch.indices[:100]):
            cur_post_entropy = compute_pair_mi(idx, i, probs_B_K_C, probs_B_C,
                                               sample_B_K_C)
            mi += (prior_entropy[i] - cur_post_entropy) / len(batch.indices)
            post_entropy += cur_post_entropy / len(batch.indices)
        mi = mi.item()
        post_entropy = post_entropy.item()
        print('post_entropy', post_entropy, 'mi', mi)

        prior_entropy, mi = compute_mi_sample(probs_B_K_C,
                                              sample_B_K_C,
                                              num_samples=50)
        print('prior_entropy', prior_entropy, 'unpooled interdependency', mi)

        original_batch_indices = get_base_indices(
            experiment_data.available_dataset, batch.indices)
        print(f"Acquiring indices {original_batch_indices}")
        targets = get_targets(experiment_data.available_dataset)
        acquired_targets = [int(targets[index]) for index in batch.indices]
        print(f"Acquiring targets {acquired_targets}")

        iterations.append(
            dict(
                num_epochs=num_epochs,
                test_metrics=test_metrics,
                active_entropy=entropy_score,
                chosen_targets=acquired_targets,
                chosen_samples=original_batch_indices,
                chosen_samples_score=batch.scores,
                chosen_samples_orignal_score=batch.orignal_scores,  # sic, matches the upstream attribute spelling
                train_model_elapsed_time=train_model_stopwatch.elapsed_time,
                batch_acquisition_elapsed_time=batch_acquisition_stopwatch.elapsed_time,
                prior_pool_entropy=prior_entropy,
                batch_pool_mi=mi,
                **to_store,
            ))

        experiment_data.active_learning_data.acquire(batch.indices)

        num_acquired_samples = len(
            experiment_data.active_learning_data.active_dataset) - len(
                experiment_data.initial_samples)
        if num_acquired_samples >= args.target_num_acquired_samples:
            print(
                f"{num_acquired_samples} acquired samples >= {args.target_num_acquired_samples}"
            )
            break
        if test_metrics["accuracy"] >= args.target_accuracy:
            print(
                f'accuracy {test_metrics["accuracy"]} >= {args.target_accuracy}'
            )
            break

    with ContextStopwatch() as train_model_stopwatch:
        early_stopping_patience = args.early_stopping_patience
        num_inference_samples = args.num_inference_samples
        log_interval = args.log_interval

        model, num_epochs, test_metrics = dataset.train_model(
            train_loader,
            test_loader,
            validation_loader,
            num_inference_samples,
            max_epochs,
            early_stopping_patience,
            desc,
            log_interval,
            device,
        )
    target_size = max(
        args.min_candidates_per_acquired_item * args.available_sample_k,
        len(available_loader.dataset) * args.min_remaining_percentage // 100,
    )
    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=model,
        acquisition_function=AcquisitionFunction.predictive_entropy,
        num_classes=dataset.num_classes,
        k=args.num_inference_samples,
        initial_percentage=args.initial_percentage,
        reduce_percentage=args.reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )

    probs_B_K_C = result.logits_B_K_C.exp()
    B, K, C = probs_B_K_C.shape  # (pool size, num MC samples, num classes)
    probs_B_C = probs_B_K_C.mean(dim=1)
    prior_entropy = -(probs_B_C * probs_B_C.log()).sum(dim=-1)  # (pool size,)

    print('final prior_entropy', prior_entropy.mean().item(), 'mi (last AL iteration)', mi)

    print("DONE")
Example #4
# Assumes the imports from Example #1, plus `from torch.utils import data` and
# the repo's torch_utils helpers.
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description="Pure training loop without AL",
        formatter_class=functools.partial(argparse.ArgumentDefaultsHelpFormatter, width=120),
    )
    parser.add_argument("--batch_size", type=int, default=64, help="input batch size for training")
    parser.add_argument("--scoring_batch_size", type=int, default=256, help="input batch size for scoring")
    parser.add_argument("--test_batch_size", type=int, default=256, help="input batch size for testing")
    parser.add_argument("--validation_set_size", type=int, default=128, help="validation set size")
    parser.add_argument(
        "--early_stopping_patience", type=int, default=1, help="# patience epochs for early stopping per iteration"
    )
    parser.add_argument("--epochs", type=int, default=30, help="number of epochs to train")
    parser.add_argument("--epoch_samples", type=int, default=5056, help="number of epochs to train")
    parser.add_argument("--quickquick", action="store_true", default=False, help="uses a very reduced dataset")
    parser.add_argument(
        "--balanced_validation_set",
        action="store_true",
        default=False,
        help="uses a balanced validation set (instead of randomly picked)"
        "(and if no validation set is provided by the dataset)",
    )
    parser.add_argument("--num_inference_samples", type=int, default=5, help="number of samples for inference")
    parser.add_argument("--no_cuda", action="store_true", default=False, help="disables CUDA training")
    parser.add_argument(
        "--name", type=str, default="results", help="name for the results file (name of the experiment)"
    )
    parser.add_argument("--seed", type=int, default=1, help="random seed")
    parser.add_argument(
        "--train_dataset_limit",
        type=int,
        default=0,
        help="how much of the training set to use for training after splitting off the validation set (0 for all)",
    )
    parser.add_argument(
        "--balanced_training_set",
        action="store_true",
        default=False,
        help="uses a balanced training set (instead of randomly picked)"
        "(and if no validation set is provided by the dataset)",
    )
    parser.add_argument(
        "--balanced_test_set",
        action="store_true",
        default=False,
        help="force balances the test set---use with CAUTION!",
    )
    parser.add_argument(
        "--log_interval", type=int, default=10, help="how many batches to wait before logging training status"
    )
    parser.add_argument(
        "--dataset",
        type=DatasetEnum,
        default=DatasetEnum.mnist,
        help=f"dataset to use (options: {[f.name for f in DatasetEnum]})",
    )
    args = parser.parse_args()

    store = laaos.create_file_store(
        args.name,
        suffix="",
        truncate=False,
        type_handlers=(blackhc.laaos.StrEnumHandler(), blackhc.laaos.ToReprHandler()),
    )
    store["args"] = args.__dict__

    print(args.__dict__)

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    print(f"Using {device} for computations")

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    dataset: DatasetEnum = args.dataset

    data_source = dataset.get_data_source()

    reduced_train_length = args.train_dataset_limit

    experiment_data = get_experiment_data(
        data_source,
        dataset.num_classes,
        None,
        False,
        0,
        args.validation_set_size,
        args.balanced_test_set,
        args.balanced_validation_set,
    )

    if not reduced_train_length:
        reduced_train_length = len(experiment_data.available_dataset)

    print(f"Training with reduced dataset of {reduced_train_length} data points")
    if not args.balanced_training_set:
        experiment_data.active_learning_data.acquire(
            experiment_data.active_learning_data.get_random_available_indices(reduced_train_length)
        )
    else:
        print("Using a balanced training set.")
        num_samples_per_class = reduced_train_length // dataset.num_classes
        experiment_data.active_learning_data.acquire(
            list(
                itertools.chain.from_iterable(
                    torch_utils.get_balanced_sample_indices(
                        get_targets(experiment_data.available_dataset), dataset.num_classes, num_samples_per_class
                    ).values()
                )
            )
        )
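    # The balanced branch acquires reduced_train_length // num_classes points
    # per class, so it can end up with up to num_classes - 1 fewer points than
    # the random branch.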

    if len(experiment_data.train_dataset) < args.epoch_samples:
        sampler = RandomFixedLengthSampler(experiment_data.train_dataset, args.epoch_samples)
    else:
        sampler = data.RandomSampler(experiment_data.train_dataset)
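    # Pad short training sets to a fixed epoch length; once the acquired set
    # reaches args.epoch_samples, a plain random sampler suffices.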

    test_loader = torch.utils.data.DataLoader(
        experiment_data.test_dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs
    )
    train_loader = torch.utils.data.DataLoader(
        experiment_data.train_dataset, sampler=sampler, batch_size=args.batch_size, **kwargs
    )

    validation_loader = torch.utils.data.DataLoader(
        experiment_data.validation_dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs
    )

    def desc(name):
        return lambda engine: "%s" % name

    dataset.train_model(
        train_loader=train_loader,
        test_loader=test_loader,
        validation_loader=validation_loader,
        num_inference_samples=args.num_inference_samples,
        max_epochs=args.epochs,
        early_stopping_patience=args.early_stopping_patience,
        desc=desc,
        log_interval=args.log_interval,
        device=device,
        epoch_results_store=store,
    )
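

# Entry point guard (assumed, as in the other examples):
if __name__ == "__main__":
    main()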