import math
import torch

from Utils.common import net_weights_magnitude


def get_hyper_divergnce(prm, prior_model):
    ''' Calculates a divergence between the hyper-prior and the hyper-posterior,
        which is, in our case, just a regularization term over the prior parameters. '''
    # Note: the hyper-prior is N(0, kappa_prior^2 * I)
    # Note: the hyper-posterior is N(parameters-of-prior-distribution, kappa_post^2 * I)
    kappa_post = prm.kappa_post
    kappa_prior = prm.kappa_prior
    d = prior_model.weights_count  # dimension of the prior-parameters vector

    if prm.divergence_type == 'KL':
        # Closed-form KL divergence between the two isotropic d-dimensional Gaussians,
        # KL( N(mu, kappa_post^2*I) || N(0, kappa_prior^2*I) )
        # (note the dimension factor d on the variance, log and constant terms):
        norm_sqr = net_weights_magnitude(prior_model, prm, p=2)  # ||mu||^2
        hyper_dvrg = (norm_sqr + d * kappa_post**2) / (2 * kappa_prior**2) \
            + d * math.log(kappa_prior / kappa_post) - d / 2
    elif prm.divergence_type == 'W_NoSqr':
        # Wasserstein-2 distance (not squared) between the two Gaussians:
        hyper_dvrg = torch.sqrt(net_weights_magnitude(prior_model, prm, p=2)
                                + d * (kappa_prior - kappa_post)**2)
    elif prm.divergence_type == 'W_Sqr':
        # Squared Wasserstein-2 distance between the two Gaussians:
        hyper_dvrg = net_weights_magnitude(prior_model, prm, p=2) \
            + d * (kappa_prior - kappa_post)**2
    else:
        raise ValueError('Invalid prm.divergence_type')
    assert hyper_dvrg >= 0
    return hyper_dvrg
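# -------------------------------------------------------------------------------------------
# Sanity-check sketch (illustrative, not part of the repo): verifies the closed-form KL term
# used in the 'KL' branch above against torch.distributions, for
# KL( N(mu, kappa_post^2*I) || N(0, kappa_prior^2*I) ). All names below are local to this
# sketch; an explicit mean vector mu stands in for the prior-model parameters.
# -------------------------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.distributions import Normal, kl_divergence

    def kl_isotropic(mu, kappa_post, kappa_prior):
        # Closed form for d-dimensional isotropic Gaussians (same formula as the 'KL' branch)
        d = mu.numel()
        norm_sqr = (mu ** 2).sum()
        return (norm_sqr + d * kappa_post**2) / (2 * kappa_prior**2) \
            + d * math.log(kappa_prior / kappa_post) - d / 2

    mu = torch.randn(10)
    kappa_post, kappa_prior = 0.5, 2.0
    closed_form = kl_isotropic(mu, kappa_post, kappa_prior)
    # Factorized Gaussians: the total KL is the sum of the per-dimension KLs
    reference = kl_divergence(Normal(mu, kappa_post * torch.ones(10)),
                              Normal(torch.zeros(10), kappa_prior * torch.ones(10))).sum()
    print(closed_form.item(), reference.item())  # the two values should agree up to float error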
set_model_values(prior_model, prm.prior_mean, prm.prior_log_var)

# -------------------------------------------------------------------------------------------
# Plot epochs figure
# -------------------------------------------------------------------------------------------
learn_single_Bayes.plot_log(log_mat, prm, val_types_for_show=None, y_axis_lim=[0, 1])

# -------------------------------------------------------------------------------------------
# Analyze the final posterior
# -------------------------------------------------------------------------------------------
from Utils.common import net_weights_magnitude

weight_norm = torch.sqrt(net_weights_magnitude(post_model, prm))
print('Final posterior weights norm: {}'.format(weight_norm))

max_weight = 0.0
for (param_name, param) in post_model.named_parameters():
    max_weight = max(max_weight, param.abs().max().item())
print('Final posterior max weight: {}'.format(max_weight))

avg_weight = 0.0
for (param_name, param) in post_model.named_parameters():
    avg_weight += param.abs().sum().item()
avg_weight /= post_model.weights_count
print('Final posterior avg weight: {}'.format(avg_weight))

# mu diff posterior - prior:
from Models.stochastic_layers import StochasticLayer
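# -------------------------------------------------------------------------------------------
# Illustrative sketch (the original analysis is truncated above): one plausible way to measure
# the shift of the posterior means from the prior means, layer by layer. The attribute names
# w['mean'] / b['mean'] are hypothetical; adjust them to whatever StochasticLayer actually
# exposes for its mean parameters.
# -------------------------------------------------------------------------------------------
mu_diff_sqr = 0.0
for (post_layer, prior_layer) in zip(post_model.children(), prior_model.children()):
    if isinstance(post_layer, StochasticLayer):
        # Hypothetical layout: each stochastic layer holds mean / log-var tensors
        # for its weights and biases
        mu_diff_sqr += (post_layer.w['mean'] - prior_layer.w['mean']).pow(2).sum().item()
        mu_diff_sqr += (post_layer.b['mean'] - prior_layer.b['mean']).pow(2).sum().item()
print('||mu_post - mu_prior||: {}'.format(mu_diff_sqr ** 0.5))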