def do_train_dirac_one_image(verbose=True):
    """Train a fresh GaussianClassifier with a Dirac step-function BBBLoss, then evaluate it.

    The network is trained with ``train_bayesian_modular_with_one_different``
    using a ``BBBLoss`` whose step function returns 1 only for batch index -1
    and 0 for every other batch. The trained network is then evaluated with
    ``eval_bayesian``.

    Relies on the module-level globals ``rho``, ``device``, ``criterion``,
    ``nb_of_epochs``, ``trainloader``, ``evalloader`` and ``nb_of_tests``.

    Args:
        verbose: forwarded to both the training and the evaluation helper.

    Returns:
        Whatever ``eval_bayesian`` returns for the trained network.
    """
    def dirac(cur_batch, nb_of_batch):
        # Step function: weight 1 for batch index -1, 0 everywhere else.
        # NOTE(review): presumably -1 identifies the "one different" image in
        # train_bayesian_modular_with_one_different — confirm. nb_of_batch is
        # unused but kept so the signature matches the other step functions
        # (batch_idx, number_of_batches) handed to BBBLoss.
        return 1 if cur_batch == -1 else 0

    bay_net = GaussianClassifier(rho, number_of_classes=10)
    bay_net.to(device)
    loss_bbb = BBBLoss(bay_net, criterion, dirac)
    optimizer = optim.Adam(bay_net.parameters())
    observables = AccuracyAndUncertainty()
    train_bayesian_modular_with_one_different(
        bay_net,
        optimizer,
        loss_bbb,
        observables,
        number_of_epochs=nb_of_epochs,
        trainloader=trainloader,
        device=device,
        verbose=verbose,
    )

    # The nested test helpers that used to follow this return were
    # unreachable dead code and have been removed.
    return eval_bayesian(bay_net,
                         evalloader,
                         nb_of_tests,
                         device=device,
                         verbose=verbose)
def do_train_ce(verbose=True):
    """Train a fresh GaussianClassifier with plain cross-entropy, then evaluate it.

    The loss is a ``BaseLoss`` wrapping ``nn.CrossEntropyLoss``. No BBB step
    function is involved.

    Relies on the module-level globals ``rho``, ``device``, ``nb_of_epochs``,
    ``trainloader``, ``evalloader`` and ``nb_of_tests``.

    Args:
        verbose: forwarded to both the training and the evaluation helper.

    Returns:
        Whatever ``eval_bayesian`` returns for the trained network.
    """
    net = GaussianClassifier(rho, number_of_classes=10)
    net.to(device)
    ce_loss = BaseLoss(nn.CrossEntropyLoss())
    adam = optim.Adam(net.parameters())
    metrics = AccuracyAndUncertainty()

    train_bayesian_modular(net, adam, ce_loss, metrics,
                           number_of_epochs=nb_of_epochs,
                           trainloader=trainloader,
                           device=device,
                           verbose=verbose)

    evaluation = eval_bayesian(net, evalloader, nb_of_tests,
                               device=device, verbose=verbose)
    return evaluation
# Build the data loaders and record the input geometry of the chosen dataset.
# NOTE(review): any other value of `trainset` leaves the loaders and
# dim_input/dim_channels unset here, and the cifar10 branch never defines
# `valloader` — confirm downstream code does not depend on either.
if trainset == 'mnist':
    trainloader, valloader, evalloader = get_mnist(
        train_labels=range(10),
        eval_labels=range(10),
        batch_size=batch_size,
    )
    dim_input, dim_channels = 28, 1
elif trainset == 'cifar10':
    trainloader, evalloader = get_cifar10(batch_size=batch_size)
    dim_input, dim_channels = 32, 3

# Seed the RNGs, then build the Bayesian classifier sized for the dataset.
seed_model = set_and_print_random_seed()
bay_net = GaussianClassifier(
    rho=rho,
    stds_prior=stds_prior,
    dim_input=dim_input,
    number_of_classes=10,
    dim_channels=dim_channels,
)
bay_net.to(device)

# Select the loss: 'uniform' and 'exp' wrap the cross-entropy criterion in a
# BBBLoss with the matching step function; any other loss_type uses the bare
# criterion through BaseLoss.
criterion = CrossEntropyLoss()
if loss_type in ('uniform', 'exp'):
    if loss_type == 'uniform':
        step_function = uniform
    else:

        def step_function(batch_idx, number_of_batches):
            """Return 2**(M - i) / (2**M - 1) for batch i out of M batches.

            NOTE(review): these weights sum to 1 when i runs over 1..M; if
            the caller passes 0-based batch indices they sum to 2 — confirm.
            """
            power = 2**(number_of_batches - batch_idx)
            return power / (2**number_of_batches - 1)

    loss = BBBLoss(bay_net, criterion, step_function)
else:
    loss = BaseLoss(criterion)
# Loaders restricted to the training labels; evaluation reuses the same labels.
trainloader_seen, valloader_seen, evalloader_seen = get_trainset(
    train_labels=train_labels,
    eval_labels=train_labels,
    batch_size=batch_size,
    split_train=(0, split_train),
)

# Evaluation loader for the unseen set, built from the run arguments.
evalloader_unseen = get_evalloader_unseen(arguments)

# Seed the RNGs and build the model; input geometry comes from `arguments`.
seed_model = set_and_print_random_seed()
bay_net = GaussianClassifier(rho=rho, stds_prior=stds_prior,
                             dim_input=arguments['dim_input'],
                             number_of_classes=10,
                             dim_channels=arguments['dim_channels'])
bay_net.to(device)

# Base criterion that the training loss is built on.
criterion = CrossEntropyLoss()
if args.determinist:
    loss = BaseLoss(criterion)
    loss_type = 'criterion'
elif loss_type == 'uniform':
    step_function = uniform
    loss = BBBLoss(bay_net, criterion, step_function)
elif loss_type == 'exp':

    def step_function(batch_idx, number_of_batches):