def do_train_dirac_batch_same_size(verbose=True):
    """Train with a BBB loss whose KL weight is a Dirac on the first minibatch."""
    def dirac(cur_batch, nb_of_batch):
        # Put the entire KL weight on the first batch of each epoch, zero elsewhere.
        if cur_batch == 0:
            return 1
        return 0

    bay_net = GaussianClassifier(rho, number_of_classes=10)
    bay_net.to(device)
    loss_bbb = BBBLoss(bay_net, criterion, dirac)
    optimizer = optim.Adam(bay_net.parameters())
    observables = AccuracyAndUncertainty()
    train_bayesian_modular(
        bay_net,
        optimizer,
        loss_bbb,
        observables,
        number_of_epochs=nb_of_epochs,
        trainloader=trainloader,
        device=device,
        verbose=verbose,
    )

    return eval_bayesian(bay_net,
                         evalloader,
                         nb_of_tests,
                         device=device,
                         verbose=verbose)
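

# Sketch, not from the original code: in this file BBBLoss receives a weighting
# function of (cur_batch, nb_of_batch) as its third argument, so other KL
# schedules can be sketched the same way. The name `uniform` is introduced here
# for illustration only.
def uniform(cur_batch, nb_of_batch):
    # Spread the KL term evenly: every minibatch gets a weight of 1 / nb_of_batch.
    return 1 / nb_of_batch
# Usage: loss_bbb = BBBLoss(bay_net, criterion, uniform)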


def test_training_no_validation():
    """Smoke test: one epoch of training on random data, without a validation loader."""
    randomloader = RandomLoader()
    device = "cpu"
    bay_net = GaussianClassifier(-3, (0, 0), (1, 1), number_of_classes=10)
    bay_net.to(device)
    criterion = nn.CrossEntropyLoss()
    loss = BaseLoss(criterion)
    observables = AccuracyAndUncertainty()
    optimizer = optim.Adam(bay_net.parameters())

    train_bayesian_modular(bay_net,
                           optimizer,
                           loss,
                           observables,
                           number_of_epochs=1,
                           trainloader=randomloader,
                           device=device,
                           number_of_tests=2,
                           verbose=False)

def do_train_ce(verbose=True):
    """Train the same classifier with a plain cross-entropy loss wrapped in BaseLoss."""
    bay_net = GaussianClassifier(rho, number_of_classes=10)
    bay_net.to(device)
    criterion = nn.CrossEntropyLoss()
    loss_bbb = BaseLoss(criterion)
    optimizer = optim.Adam(bay_net.parameters())
    observables = AccuracyAndUncertainty()
    train_bayesian_modular(
        bay_net,
        optimizer,
        loss_bbb,
        observables,
        number_of_epochs=nb_of_epochs,
        trainloader=trainloader,
        device=device,
        verbose=verbose,
    )

    return eval_bayesian(bay_net,
                         evalloader,
                         nb_of_tests,
                         device=device,
                         verbose=verbose)


# Choose the training loss. NOTE: `use_bbb_loss` is an assumed placeholder flag;
# the original condition guarding this branch is missing from the snippet.
if use_bbb_loss:
    def step_function(batch_idx, number_of_batches):
        # Exponentially decaying KL weight over the minibatches of an epoch.
        return 2**(number_of_batches - batch_idx) / (2**number_of_batches - 1)

    loss = BBBLoss(bay_net, criterion, step_function)
else:
    loss = BaseLoss(criterion)

optimizer = optim.Adam(bay_net.parameters())
observables = AccuracyAndUncertainty()
train_bayesian_modular(
    bay_net,
    optimizer,
    loss,
    observables,
    number_of_tests=number_of_tests,
    number_of_epochs=epoch,
    trainloader=trainloader,
    # valloader=valloader,
    # output_dir_tensorboard='./output',
    device=device,
    verbose=True,
)

# Evaluate on the training set to collect the true labels and the raw MC outputs
# (return_accuracy=False skips the accuracy computation).
true_train_labels, all_outputs_train = eval_bayesian(
    bay_net,
    trainloader,
    return_accuracy=False,
    number_of_tests=number_of_tests,
    device=device,
)
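
# Sketch, not from the original code: assuming `all_outputs_train` holds softmax
# probabilities with shape (number_of_tests, n_examples, n_classes), the
# predictive distribution and its per-example entropy can be derived as below.
# `predictive_entropy` is a name introduced here for illustration only.
import torch

def predictive_entropy(all_outputs):
    # Average the sampled class probabilities over the MC-sample dimension.
    mean_probs = all_outputs.mean(dim=0)  # (n_examples, n_classes)
    # Entropy of the averaged predictive distribution, one value per example.
    return -(mean_probs * torch.log(mean_probs + 1e-12)).sum(dim=1)

# Example: uncertainty_train = predictive_entropy(all_outputs_train)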