def get_bayes_task_objective(prm,
                             prior_model,
                             post_model,
                             n_samples,
                             empirical_loss,
                             hyper_kl=0,
                             n_train_tasks=1,
                             noised_prior=True):
    """Return the ``(empirical_loss, complexity_term)`` pair for one task.

    The complexity term is chosen by ``prm.complexity_type``:

    * ``'NoComplexity'``       - a zero term created with ``cmn.zeros_gpu(1)``.
    * ``'NewBoundMcAllaster'`` - sqrt of (hyper_kl + KLD + log(2n / delta))
      scaled by 1 / (2 (n - 1)).
    * ``'NewBoundSeeger'``     - 2 * eps + sqrt(relu(2 * eps * empirical_loss)),
      with eps = (KLD + hyper_kl + log(4 sqrt(n) / delta)) / n.
    * ``'PAC_Bayes_Pentina'``  - KLD / sqrt(n) + hyper_kl / (n_train_tasks * sqrt(n)).
    * ``'Variational_Bayes'``  - the KLD itself; the returned loss is
      ``n_samples * empirical_loss``.

    Args:
        prm: parameter object; ``complexity_type`` and ``delta`` are read here
            and the object is forwarded to ``get_total_kld``.
        prior_model: prior model, forwarded to ``get_total_kld``.
        post_model: posterior model, forwarded to ``get_total_kld``.
        n_samples: sample count ``n`` used in the bound formulas.
        empirical_loss: empirical loss of the task (average loss, per the
            Variational-Bayes branch).
        hyper_kl: extra KL term added to the task KLD in the bounds.
        n_train_tasks: number of training tasks (used only by the Pentina bound).
        noised_prior: forwarded to ``get_total_kld``.

    Returns:
        Tuple ``(empirical_loss, complex_term)``; the loss is returned
        unchanged except for ``'Variational_Bayes'``.

    Raises:
        ValueError: if ``prm.complexity_type`` is not one of the types above
            (this includes the retired ``'PAC_Bayes_Seeger'`` and
            ``'PAC_Bayes_McAllaster'`` names).
    """
    bound_kind = prm.complexity_type
    confidence = prm.delta  # maximal probability that the bound does not hold
    # KLD between posterior and sampled prior
    kld = get_total_kld(prior_model, post_model, prm, noised_prior)

    if bound_kind == 'NoComplexity':
        # No regularisation: the complexity term is a constant zero.
        return empirical_loss, Variable(cmn.zeros_gpu(1), requires_grad=False)

    if bound_kind == 'NewBoundMcAllaster':
        scale = 1 / (2 * (n_samples - 1))
        inner = hyper_kl + kld + math.log(2 * n_samples / confidence)
        return empirical_loss, torch.sqrt(scale * inner)

    if bound_kind == 'NewBoundSeeger':
        eps = (1 / n_samples) * (
            kld + hyper_kl + math.log(4 * math.sqrt(n_samples) / confidence))
        # relu guards against slightly negative values caused by numerical errors
        root_arg = F.relu(2 * eps * empirical_loss)
        return empirical_loss, 2 * eps + torch.sqrt(root_arg)

    if bound_kind == 'PAC_Bayes_Pentina':
        task_part = math.sqrt(1 / n_samples) * kld
        meta_part = hyper_kl * (1 / (n_train_tasks * math.sqrt(n_samples)))
        return empirical_loss, task_part + meta_part

    if bound_kind == 'Variational_Bayes':
        # The expectation of the likelihood over all samples is approximated,
        # so the average loss is scaled up by the total number of samples.
        return n_samples * empirical_loss, kld

    raise ValueError('Invalid complexity_type')
def run_test_majority_vote(model, test_loader, loss_criterion, prm, n_votes=9):
    """Evaluate ``model`` on ``test_loader`` with a per-sample majority vote.

    Every batch is passed through ``model`` ``n_votes`` times. Each pass casts
    one vote per sample for its arg-max class, and the class with the most
    votes is taken as the prediction (useful when the forward pass is
    stochastic - presumably the case for the models used here).

    Args:
        model: callable network; ``model.eval()`` is called before testing.
        test_loader: iterable of batches that exposes ``.dataset`` with a length.
        loss_criterion: callable ``(outputs, targets) -> loss``.
        prm: run parameters, forwarded to ``data_gen.get_batch_vars`` and
            ``data_gen.get_info``.
        n_votes: number of forward passes (votes) per batch.

    Returns:
        dict with keys ``'test_acc'``, ``'n_correct'``, ``'test_type'``
        (always ``'majority_vote'``), ``'n_test_samples'`` and ``'test_loss'``.
    """
    n_test_samples = len(test_loader.dataset)
    model.eval()
    # The number of classes depends only on prm, so look it up once instead of
    # once per batch.
    n_labels = data_gen.get_info(prm)['n_classes']
    test_loss = 0
    n_correct = 0
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data,
                                                  prm,
                                                  is_test=True)

        batch_size = inputs.shape[0]  # actual size; a batch may be short
        votes = cmn.zeros_gpu((batch_size, n_labels))
        # One unit vote per sample, created with the dtype/device of `votes`.
        one_vote = votes.new_ones((batch_size, 1))
        for _ in range(n_votes):
            outputs = model(inputs)
            # Detach before accumulating so the running total does not keep
            # every forward pass's autograd graph alive.
            # NOTE(review): the loss is added once per vote but only divided
            # by n_test_samples below, so it scales with n_votes - confirm
            # whether that is intended.
            test_loss += loss_criterion(outputs, targets).detach()
            # index of the max output, shape (batch_size, 1)
            pred = outputs.data.max(1, keepdim=True)[1]
            # votes[i, pred[i]] += 1 for all samples in a single call,
            # replacing a per-sample Python loop with a host copy per sample.
            votes.scatter_add_(1, pred, one_vote)

        # arg-max class of the vote histogram for each sample
        majority_pred = votes.max(1, keepdim=True)[1]
        n_correct += majority_pred.eq(
            targets.data.view_as(majority_pred)).cpu().sum()
    test_loss /= n_test_samples
    # float() forces true division even when n_correct is an integer tensor
    # (integer tensor division truncates on some PyTorch versions).
    test_acc = float(n_correct) / n_test_samples
    info = {
        'test_acc': test_acc,
        'n_correct': n_correct,
        'test_type': 'majority_vote',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(test_loss)
    }
    return info
def run_test_avg_vote(model, test_loader, loss_criterion, prm, n_votes=5):
    """Evaluate ``model`` on ``test_loader`` by summing outputs over several passes.

    Every batch is passed through ``model`` ``n_votes`` times; the raw outputs
    of all passes are summed and the arg-max of the sum is the prediction
    (equivalent to the arg-max of the average output).

    Args:
        model: callable network; ``model.eval()`` is called before testing.
        test_loader: iterable of batches that exposes ``.dataset`` with a length.
        loss_criterion: callable ``(outputs, targets) -> loss``.
        prm: run parameters, forwarded to ``data_gen.get_batch_vars`` and
            ``data_gen.get_info``.
        n_votes: number of forward passes per batch.

    Returns:
        dict with keys ``'test_acc'``, ``'n_correct'``, ``'test_type'``
        (always ``'AvgVote'``), ``'n_test_samples'`` and ``'test_loss'``.
    """
    n_test_samples = len(test_loader.dataset)
    model.eval()
    # The number of classes depends only on prm, so look it up once instead of
    # once per batch.
    n_labels = data_gen.get_info(prm)['n_classes']
    test_loss = 0
    n_correct = 0
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data,
                                                  prm,
                                                  is_test=True)

        # Size the accumulator from the actual batch. Using
        # min(prm.test_batch_size, n_test_samples) breaks `votes += outputs`
        # with a shape mismatch whenever a batch (e.g. the last one) is
        # shorter than the configured batch size.
        batch_size = inputs.shape[0]
        votes = cmn.zeros_gpu((batch_size, n_labels))
        for _ in range(n_votes):
            outputs = model(inputs)
            # Detach before accumulating so the running total does not keep
            # every forward pass's autograd graph alive.
            # NOTE(review): the loss is added once per vote but only divided
            # by n_test_samples below, so it scales with n_votes - confirm
            # whether that is intended.
            test_loss += loss_criterion(outputs, targets).detach()
            votes += outputs.data

        # arg-max class of the summed outputs for each sample
        avg_pred = votes.max(1, keepdim=True)[1]
        n_correct += avg_pred.eq(targets.data.view_as(avg_pred)).cpu().sum()

    test_loss /= n_test_samples
    # float() forces true division even when n_correct is an integer tensor
    # (integer tensor division truncates on some PyTorch versions).
    test_acc = float(n_correct) / n_test_samples
    info = {
        'test_acc': test_acc,
        'n_correct': n_correct,
        'test_type': 'AvgVote',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(test_loss)
    }
    return info
        else:
            avg_param_vec += parameters_to_vector(curr_model.parameters()) * (1 / n_train_tasks)

    avg_model = deterministic_models.get_model(prm)
    vector_to_parameters(avg_param_vec, avg_model.parameters())

    # create the prior model:
    prior_model = stochastic_models.get_model(prm)
    prior_layers_list = [layer for layer in prior_model.modules() if isinstance(layer, StochasticLayer)]
    avg_model_layers_list = [layer for layer in avg_model.modules()
                             if isinstance(layer, torch.nn.Conv2d) or isinstance(layer, torch.nn.Linear)]
    assert len(avg_model_layers_list)==len(prior_layers_list), "lists not equal"

    for i_layer, prior_layer in enumerate(prior_layers_list):
        if hasattr(prior_layer, 'w'):
            prior_layer.w['log_var'] = torch.nn.Parameter(zeros_gpu(1))
            prior_layer.w['mean'] = avg_model_layers_list[i_layer].weight
        if hasattr(prior_layer, 'b'):
            prior_layer.b['log_var'] = torch.nn.Parameter(zeros_gpu(1))
            prior_layer.b['mean'] = avg_model_layers_list[i_layer].bias


    # save learned prior:
    f_path = save_model_state(prior_model, save_path)
    print('Trained prior saved in ' + f_path)


elif prm.mode == 'LoadMetaModel':

    # Loads  previously training prior.
    # First, create the model:
# Example no. 5
# 0
def get_bayes_task_objective(prm,
                             prior_model,
                             post_model,
                             n_samples,
                             empirical_loss,
                             hyper_kl=0,
                             n_train_tasks=1,
                             noised_prior=True):
    """Return the ``(empirical_loss, complexity_term)`` pair for one task.

    The complexity term is chosen by ``prm.complexity_type``. Below, ``n`` is
    ``n_samples``, ``T`` is ``n_train_tasks``, ``KLD`` is the value returned by
    ``get_total_kld`` and ``S = KLD + hyper_kl + log(4 T sqrt(n) / delta)``:

    * ``'NoComplexity'``               - zero term from ``cmn.zeros_gpu(1)``.
    * ``'NewBoundMcAllaster'``         - sqrt((hyper_kl + KLD + log(2 n T / delta)) / (2 (n - 1))).
    * ``'NewBoundSeeger'``             - 2 eps + sqrt(relu(2 eps loss)), with
      eps = (KLD + hyper_kl + log(4 sqrt(n) / delta)) / n.
    * ``'PAC_Bayes_Pentina'``          - KLD / sqrt(n) + hyper_kl / (T sqrt(n)).
    * ``'Variational_Bayes'``          - KLD; the returned loss is ``n * loss``.
    * ``'PAC_Bayes_quad'``             - 2 (q + sqrt((q + loss) q)), q = S / (2 n).
    * ``'PAC_Bayes_lambda'``           - S / (n lam (1 - lam / 2)) with lam = 1.
    * ``'PAC_Bayes_variational_kl'``   - element-wise min of the quad and lambda terms.
    * ``'PAC_Bayes_variational_role'`` - min(e + sqrt(e (e + 2 loss)), sqrt(e / 2)),
      e = S / n.

    Args:
        prm: parameter object; ``complexity_type`` and ``delta`` are read here
            and the object is forwarded to ``get_total_kld``.
        prior_model: prior model, forwarded to ``get_total_kld``.
        post_model: posterior model, forwarded to ``get_total_kld``.
        n_samples: sample count ``n`` used in the bound formulas.
        empirical_loss: empirical loss of the task.
        hyper_kl: extra KL term added to the task KLD in the bounds.
        n_train_tasks: number of training tasks ``T``.
        noised_prior: forwarded to ``get_total_kld``.

    Returns:
        Tuple ``(empirical_loss, complex_term)``; the loss is returned
        unchanged except for ``'Variational_Bayes'``.

    Raises:
        ValueError: if ``prm.complexity_type`` is not one of the types above
            (this includes the retired ``'PAC_Bayes_Seeger'`` and
            ``'PAC_Bayes_McAllaster'`` names).
    """
    bound_kind = prm.complexity_type
    delta = prm.delta  # maximal probability that the bound does not hold
    # KLD between posterior and sampled prior
    kld = get_total_kld(prior_model, post_model, prm, noised_prior)

    def kl_plus_log():
        # Shared numerator of the newer bounds: task KLD + hyper KL + log term.
        return kld + hyper_kl + math.log(
            4 * n_train_tasks * math.sqrt(n_samples) / delta)

    def quad_term():
        # Perez-Ortiz M, Rivasplata O, Shawe-Taylor J, et al. Tighter risk
        # certificates for neural networks. arXiv:2007.12911, 2020.
        q = (1 / (2 * n_samples)) * kl_plus_log()
        root_arg = (q + empirical_loss) * q
        return 2 * (q + torch.sqrt(root_arg))

    def lambda_term():
        # Same reference as quad_term; lambda is fixed to 1 here.
        lam = 1
        return 1 / (n_samples * lam * (1 - lam / 2)) * kl_plus_log()

    if bound_kind == 'NoComplexity':
        # No regularisation: the complexity term is a constant zero.
        return empirical_loss, Variable(cmn.zeros_gpu(1), requires_grad=False)

    if bound_kind == 'NewBoundMcAllaster':
        numerator = (hyper_kl + kld +
                     math.log(2 * n_samples * n_train_tasks / delta))
        return empirical_loss, torch.sqrt(numerator / (2 * (n_samples - 1)))

    if bound_kind == 'NewBoundSeeger':
        eps = (1 / n_samples) * (
            kld + hyper_kl + math.log(4 * math.sqrt(n_samples) / delta))
        # relu guards against slightly negative values caused by numerical errors
        root_arg = F.relu(2 * eps * empirical_loss)
        return empirical_loss, 2 * eps + torch.sqrt(root_arg)

    if bound_kind == 'PAC_Bayes_Pentina':
        task_part = math.sqrt(1 / n_samples) * kld
        meta_part = hyper_kl * (1 / (n_train_tasks * math.sqrt(n_samples)))
        return empirical_loss, task_part + meta_part

    if bound_kind == 'Variational_Bayes':
        # The expectation of the likelihood over all samples is approximated,
        # so the average loss is scaled up by the total number of samples.
        return n_samples * empirical_loss, kld

    # ---------------------- New PAC-Bayes bound tests ---------------------- #
    if bound_kind == 'PAC_Bayes_quad':
        return empirical_loss, quad_term()

    if bound_kind == 'PAC_Bayes_lambda':
        return empirical_loss, lambda_term()

    if bound_kind == 'PAC_Bayes_variational_kl':
        # Tighter of the quad and lambda bounds.
        return empirical_loss, torch.min(quad_term(), lambda_term())

    if bound_kind == 'PAC_Bayes_variational_role':
        # Dziugaite G K, Hsu K, Gharbieh W, et al. On the role of data in
        # PAC-Bayes bounds. arXiv:2006.10929, 2020.
        eps = 1 / n_samples * kl_plus_log()
        quad_bound = eps + torch.sqrt(eps * (eps + 2 * empirical_loss))
        pinsker_bound = torch.sqrt(eps / 2)
        return empirical_loss, torch.min(quad_bound, pinsker_bound)

    raise ValueError('Invalid complexity_type')