def get_bayes_task_objective(prm,
                             prior_model,
                             post_model,
                             n_samples,
                             empirical_loss,
                             hyper_kl=0,
                             n_train_tasks=1,
                             noised_prior=True):
    """Return the empirical loss and the complexity term of one task.

    The complexity term is chosen by ``prm.complexity_type``. Every variant
    except 'NoComplexity' is built from the total KL divergence returned by
    ``get_total_kld``.

    Args:
        prm: Settings object. ``complexity_type`` and ``delta`` are read
            here, and the object is forwarded to ``get_total_kld``.
        prior_model: Prior model, forwarded to ``get_total_kld``.
        post_model: Posterior model, forwarded to ``get_total_kld``.
        n_samples: Number of samples used in the bound formulas.
        empirical_loss: Empirical loss of the task. The 'Variational_Bayes'
            branch treats it as an average over the samples.
        hyper_kl: Extra KL term added to the task KL in the 'NewBound*'
            and 'PAC_Bayes_Pentina' variants.
        n_train_tasks: Number of training tasks; only used by
            'PAC_Bayes_Pentina' in this function.
        noised_prior: Forwarded to ``get_total_kld``.

    Returns:
        Tuple ``(empirical_loss, complex_term)``. ``empirical_loss`` is
        returned as given, except for 'Variational_Bayes' where it is
        multiplied by ``n_samples``.

    Raises:
        ValueError: If ``prm.complexity_type`` is not a supported name. The
            legacy 'PAC_Bayes_Seeger' and 'PAC_Bayes_McAllaster' options are
            disabled and therefore also end up here.
    """
    bound_name = prm.complexity_type
    delta = prm.delta  # maximal probability that the bound does not hold

    # KLD between posterior and sampled prior. Computed for every variant,
    # including 'NoComplexity', exactly once per call.
    total_kl = get_total_kld(prior_model, post_model, prm, noised_prior)

    if bound_name == 'NoComplexity':
        # The complexity term is a constant zero that takes no gradient.
        zero_term = Variable(cmn.zeros_gpu(1), requires_grad=False)
        return empirical_loss, zero_term

    if bound_name == 'NewBoundMcAllaster':
        scale = 1 / (2 * (n_samples - 1))
        inner = hyper_kl + total_kl + math.log(2 * n_samples / delta)
        return empirical_loss, torch.sqrt(scale * inner)

    if bound_name == 'NewBoundSeeger':
        # This variant is the only one here that mixes the empirical loss
        # into the complexity term.
        seeger_eps = (1 / n_samples) * (
            total_kl + hyper_kl + math.log(4 * math.sqrt(n_samples) / delta))
        # relu prevents negative values due to numerical errors
        root_arg = F.relu(2 * seeger_eps * empirical_loss)
        return empirical_loss, 2 * seeger_eps + torch.sqrt(root_arg)

    if bound_name == 'PAC_Bayes_Pentina':
        task_part = math.sqrt(1 / n_samples) * total_kl
        hyper_part = hyper_kl * (1 / (n_train_tasks * math.sqrt(n_samples)))
        return empirical_loss, task_part + hyper_part

    if bound_name == 'Variational_Bayes':
        # Since we approximate the expectation of the likelihood of all
        # samples, the average loss is multiplied by the total number of
        # samples.
        return n_samples * empirical_loss, total_kl

    raise ValueError('Invalid complexity_type')
def run_test_majority_vote(model, test_loader, loss_criterion, prm, n_votes=9):
    """Evaluate ``model`` on ``test_loader`` with a per-sample majority vote.

    Every batch is passed through the model ``n_votes`` times. Each pass
    casts one vote per sample for its arg-max class, and the class with the
    most votes is the final prediction for that sample.

    Args:
        model: Network to evaluate; ``model.eval()`` is called first.
            Presumably stochastic, so that repeated passes can disagree
            (not verifiable from this block).
        test_loader: Iterable of batches whose ``dataset`` attribute
            supports ``len()``.
        loss_criterion: Called as ``loss_criterion(outputs, targets)``.
        prm: Settings object forwarded to the ``data_gen`` helpers.
        n_votes: Number of forward passes per batch.

    Returns:
        dict with keys 'test_acc', 'n_correct', 'test_type'
        ('majority_vote'), 'n_test_samples' and 'test_loss'.
    """
    # Must be defined: it normalises the loss and the accuracy below.
    n_test_samples = len(test_loader.dataset)
    model.eval()
    test_loss = 0
    n_correct = 0
    # Looked up once; assumes get_info depends only on prm, which is all it
    # is given.
    n_labels = data_gen.get_info(prm)['n_classes']
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm, is_test=True)
        # Use the real batch size: the last batch may be smaller than the
        # configured test batch size.
        batch_size = inputs.shape[0]
        votes = cmn.zeros_gpu((batch_size, n_labels))
        for _ in range(n_votes):
            outputs = model(inputs)
            # The loss is accumulated over every pass, not only once per batch.
            test_loss += loss_criterion(outputs, targets)
            # Index of the max output for each sample, shape (batch, 1).
            pred = outputs.data.max(1, keepdim=True)[1]
            # One host transfer per pass instead of one per sample.
            for i_sample, pred_val in enumerate(pred.view(-1).tolist()):
                votes[i_sample, pred_val] += 1
        # Arg-max class for each sample.
        majority_pred = votes.max(1, keepdim=True)[1]
        n_correct += majority_pred.eq(
            targets.data.view_as(majority_pred)).cpu().sum()
    test_loss /= n_test_samples
    test_acc = n_correct / n_test_samples
    info = {
        'test_acc': test_acc,
        'n_correct': n_correct,
        'test_type': 'majority_vote',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(test_loss)
    }
    return info
def run_test_avg_vote(model, test_loader, loss_criterion, prm, n_votes=5):
    """Evaluate ``model`` on ``test_loader`` by summing outputs over passes.

    Every batch is passed through the model ``n_votes`` times and the raw
    outputs are summed. The arg-max of the summed outputs is the final
    prediction for each sample.

    Args:
        model: Network to evaluate; ``model.eval()`` is called first.
            Presumably stochastic, so that repeated passes can differ
            (not verifiable from this block).
        test_loader: Iterable of batches whose ``dataset`` attribute
            supports ``len()``.
        loss_criterion: Called as ``loss_criterion(outputs, targets)``.
        prm: Settings object forwarded to the ``data_gen`` helpers.
        n_votes: Number of forward passes per batch.

    Returns:
        dict with keys 'test_acc', 'n_correct', 'test_type' ('AvgVote'),
        'n_test_samples' and 'test_loss'.
    """
    n_test_samples = len(test_loader.dataset)
    model.eval()
    test_loss = 0
    n_correct = 0
    # Looked up once; assumes get_info depends only on prm, which is all it
    # is given.
    n_labels = data_gen.get_info(prm)['n_classes']
    for batch_data in test_loader:
        inputs, targets = data_gen.get_batch_vars(batch_data, prm, is_test=True)
        # Size the accumulator from the actual batch. The configured test
        # batch size is wrong for a smaller final batch and would make the
        # in-place sum below fail or broadcast incorrectly.
        batch_size = inputs.shape[0]
        votes = cmn.zeros_gpu((batch_size, n_labels))
        for _ in range(n_votes):
            outputs = model(inputs)
            # The loss is accumulated over every pass, not only once per batch.
            test_loss += loss_criterion(outputs, targets)
            votes += outputs.data
        # Arg-max of the summed outputs for each sample.
        majority_pred = votes.max(1, keepdim=True)[1]
        n_correct += majority_pred.eq(
            targets.data.view_as(majority_pred)).cpu().sum()
    test_loss /= n_test_samples
    test_acc = n_correct / n_test_samples
    info = {
        'test_acc': test_acc,
        'n_correct': n_correct,
        'test_type': 'AvgVote',
        'n_test_samples': n_test_samples,
        'test_loss': get_value(test_loss)
    }
    return info
else: avg_param_vec += parameters_to_vector(curr_model.parameters()) * (1 / n_train_tasks) avg_model = deterministic_models.get_model(prm) vector_to_parameters(avg_param_vec, avg_model.parameters()) # create the prior model: prior_model = stochastic_models.get_model(prm) prior_layers_list = [layer for layer in prior_model.modules() if isinstance(layer, StochasticLayer)] avg_model_layers_list = [layer for layer in avg_model.modules() if isinstance(layer, torch.nn.Conv2d) or isinstance(layer, torch.nn.Linear)] assert len(avg_model_layers_list)==len(prior_layers_list), "lists not equal" for i_layer, prior_layer in enumerate(prior_layers_list): if hasattr(prior_layer, 'w'): prior_layer.w['log_var'] = torch.nn.Parameter(zeros_gpu(1)) prior_layer.w['mean'] = avg_model_layers_list[i_layer].weight if hasattr(prior_layer, 'b'): prior_layer.b['log_var'] = torch.nn.Parameter(zeros_gpu(1)) prior_layer.b['mean'] = avg_model_layers_list[i_layer].bias # save learned prior: f_path = save_model_state(prior_model, save_path) print('Trained prior saved in ' + f_path) elif prm.mode == 'LoadMetaModel': # Loads previously training prior. # First, create the model:
def get_bayes_task_objective(prm,
                             prior_model,
                             post_model,
                             n_samples,
                             empirical_loss,
                             hyper_kl=0,
                             n_train_tasks=1,
                             noised_prior=True):
    """Return the empirical loss and the complexity term of one task.

    The complexity term is chosen by ``prm.complexity_type``. Every variant
    except 'NoComplexity' is built from the total KL divergence returned by
    ``get_total_kld``.

    Args:
        prm: Settings object. ``complexity_type`` and ``delta`` are read
            here, and the object is forwarded to ``get_total_kld``.
        prior_model: Prior model, forwarded to ``get_total_kld``.
        post_model: Posterior model, forwarded to ``get_total_kld``.
        n_samples: Number of samples used in the bound formulas.
        empirical_loss: Empirical loss of the task. The 'Variational_Bayes'
            branch treats it as an average over the samples.
        hyper_kl: Extra KL term added to the task KL in every variant except
            'NoComplexity' and 'Variational_Bayes'.
        n_train_tasks: Number of training tasks. It enters the log term of
            'NewBoundMcAllaster' and of the 'PAC_Bayes_quad', '_lambda',
            '_variational_kl' and '_variational_role' variants, and scales
            the hyper-KL part of 'PAC_Bayes_Pentina'.
        noised_prior: Forwarded to ``get_total_kld``.

    Returns:
        Tuple ``(empirical_loss, complex_term)``. ``empirical_loss`` is
        returned as given, except for 'Variational_Bayes' where it is
        multiplied by ``n_samples``.

    Raises:
        ValueError: If ``prm.complexity_type`` is not a supported name. The
            legacy 'PAC_Bayes_Seeger' and 'PAC_Bayes_McAllaster' options are
            disabled and therefore also end up here.
    """
    bound_name = prm.complexity_type
    delta = prm.delta  # maximal probability that the bound does not hold

    # KLD between posterior and sampled prior. Computed for every variant,
    # including 'NoComplexity', exactly once per call.
    total_kl = get_total_kld(prior_model, post_model, prm, noised_prior)

    # Fixed trade-off value used by the lambda-style bounds below.
    pac_lambda = 1

    def kl_plus_confidence():
        # Term shared by the quad / lambda / variational variants. It is
        # evaluated on demand so that unrelated variants never compute it.
        return total_kl + hyper_kl + math.log(
            4 * n_train_tasks * math.sqrt(n_samples) / delta)

    if bound_name == 'NoComplexity':
        # The complexity term is a constant zero that takes no gradient.
        zero_term = Variable(cmn.zeros_gpu(1), requires_grad=False)
        return empirical_loss, zero_term

    if bound_name == 'NewBoundMcAllaster':
        numerator = hyper_kl + total_kl + math.log(
            2 * n_samples * n_train_tasks / delta)
        return empirical_loss, torch.sqrt(numerator / (2 * (n_samples - 1)))

    if bound_name == 'NewBoundSeeger':
        seeger_eps = (1 / n_samples) * (
            total_kl + hyper_kl + math.log(4 * math.sqrt(n_samples) / delta))
        # relu prevents negative values due to numerical errors
        root_arg = F.relu(2 * seeger_eps * empirical_loss)
        return empirical_loss, 2 * seeger_eps + torch.sqrt(root_arg)

    if bound_name == 'PAC_Bayes_Pentina':
        task_part = math.sqrt(1 / n_samples) * total_kl
        hyper_part = hyper_kl * (1 / (n_train_tasks * math.sqrt(n_samples)))
        return empirical_loss, task_part + hyper_part

    if bound_name == 'Variational_Bayes':
        # Since we approximate the expectation of the likelihood of all
        # samples, the average loss is multiplied by the total number of
        # samples.
        return n_samples * empirical_loss, total_kl

    # ------------------ New PAC-Bayes bound variants ------------------ #

    if bound_name == 'PAC_Bayes_quad':
        # Perez-Ortiz M, Rivasplata O, Shawe-Taylor J, et al. Tighter risk
        # certificates for neural networks. arXiv:2007.12911, 2020.
        quad_eps = (1 / (2 * n_samples)) * kl_plus_confidence()
        # No relu clamp is applied to the square-root argument here.
        root_arg = (quad_eps + empirical_loss) * quad_eps
        return empirical_loss, 2 * (quad_eps + torch.sqrt(root_arg))

    if bound_name == 'PAC_Bayes_lambda':
        # Perez-Ortiz M, Rivasplata O, Shawe-Taylor J, et al. Tighter risk
        # certificates for neural networks. arXiv:2007.12911, 2020.
        lambda_term = 1 / (n_samples * pac_lambda * (1 - pac_lambda / 2)) * \
            kl_plus_confidence()
        return empirical_loss, lambda_term

    if bound_name == 'PAC_Bayes_variational_kl':
        # Element-wise minimum of the quad and lambda terms above.
        quad_eps = (1 / (2 * n_samples)) * kl_plus_confidence()
        root_arg = (quad_eps + empirical_loss) * quad_eps
        quad_term = 2 * (quad_eps + torch.sqrt(root_arg))
        lambda_term = 1 / (n_samples * pac_lambda * (1 - pac_lambda / 2)) * \
            kl_plus_confidence()
        return empirical_loss, torch.min(quad_term, lambda_term)

    if bound_name == 'PAC_Bayes_variational_role':
        # Dziugaite G K, Hsu K, Gharbieh W, et al. On the role of data in
        # PAC-Bayes bounds. arXiv:2006.10929, 2020.
        via_eps = 1 / n_samples * kl_plus_confidence()
        quad_term = via_eps + torch.sqrt(
            via_eps * (via_eps + 2 * empirical_loss))
        pinsker_term = torch.sqrt(via_eps / 2)
        return empirical_loss, torch.min(quad_term, pinsker_term)

    raise ValueError('Invalid complexity_type')