def __init__(
    self,
    variational_samplers,
    particles,
    cost_function=None,
    deviation_statistics=None,
    biased=False,
    number_post_samples=20000,
    gradient_estimator=gradient_estimators.PathwiseDerivativeEstimator,
):
    """Configure particle-based inference with one truncated sampler per particle.

    Every sampler in ``variational_samplers`` is paired (via ``zip``) with the
    particle at the same position in ``particles`` and wrapped with
    ``truncate_model``. The truncation rule for the i-th sampler is true
    exactly when particle i has the smallest deviation from a sample
    (``np.argmin`` over the per-particle statistics). Presumably
    ``truncate_model`` keeps only the samples for which the rule holds;
    confirm against its implementation.

    Args:
        variational_samplers: Sequence of sampler models, one per particle.
            Because it is paired with ``truncation_rules`` through ``zip``,
            extra entries in the longer sequence are silently ignored.
        particles: Sequence of particle models. Each must provide
            ``_get_sample(num_samples)``. ``len(particles)`` rules are built.
        cost_function: Callable ``(x, y) -> cost`` called with NumPy arrays of
            a sampled value and the matching particle value. When ``None``,
            defaults to ``sum_from_dim((x - y)**2, dim_index=1)``: a squared
            difference reduced by ``sum_from_dim`` from dimension 1 (presumably
            all non-batch dimensions; confirm).
        deviation_statistics: Callable that reduces the list of per-variable
            costs for one particle to a single deviation. Defaults to ``sum``.
        biased: Stored as ``self.biased``.
        number_post_samples: Stored as ``self.number_post_samples``.
        gradient_estimator: Stored as ``self.gradient_estimator``.
    """
    self.gradient_estimator = gradient_estimator
    self.learnable_posterior = True
    self.learnable_model = False  # TODO: to implement later
    self.needs_sampler = True
    self.learnable_sampler = True
    self.biased = biased
    self.number_post_samples = number_post_samples

    if cost_function is not None:
        self.cost_function = cost_function
    else:
        self.cost_function = lambda x, y: sum_from_dim((x - y)**2, dim_index=1)

    if deviation_statistics is not None:
        self.deviation_statistics = deviation_statistics
    else:
        self.deviation_statistics = sum

    def model_statistics(dic):
        """Return per-particle deviation statistics for the samples in ``dic``.

        Entry ``[j]`` of each row holds the deviation of the samples from
        particle ``j``. The result is the transpose of an array with one entry
        per particle. Its 2-D shape assumes each deviation is a length-``num_samples``
        vector; confirm.
        """
        # The sample count is read from the leading dimension of the first value.
        num_samples = list(dic.values())[0].shape[0]
        # Draw matching-size samples from every particle, re-keyed onto the
        # variables of ``dic`` so the two dicts can be zipped by variable.
        particle_samples = [
            reassign_samples(particle._get_sample(num_samples),
                             source_model=particle,
                             target_model=dic)
            for particle in particles
        ]
        statistics = [
            self.deviation_statistics([
                # .cpu() is a no-op for CPU tensors and avoids the error
                # .numpy() raises for tensors living on a CUDA device.
                self.cost_function(value_pair[0].detach().cpu().numpy(),
                                   value_pair[1].detach().cpu().numpy())
                for value_pair in zip_dict(dic, particle_sample).values()
            ])
            for particle_sample in particle_samples
        ]
        return np.array(statistics).transpose()

    # Bind ``index`` as a default argument so each rule keeps its own index
    # instead of all rules seeing the loop's final value (late binding).
    truncation_rules = [
        lambda a, idx=index: bool(idx == np.argmin(a))
        for index in range(len(particles))
    ]
    self.sampler_model = [
        truncate_model(model=sampler,
                       truncation_rule=rule,
                       model_statistics=model_statistics)
        for sampler, rule in zip(variational_samplers, truncation_rules)
    ]
# Example: build a small Normal model, truncate it with a decision rule on the
# sampled values of x, and plot the resulting density of mu and x.
import matplotlib.pyplot as plt
from brancher.variables import ProbabilisticModel
from brancher.standard_variables import NormalVariable, LogNormalVariable
from brancher.transformations import truncate_model
from brancher.visualizations import plot_density

# Normal model: mu ~ Normal(0, 1), x ~ Normal(mu, 0.1).
mu = NormalVariable(0., 1., "mu")
x = NormalVariable(mu, 0.1, "x")
model = ProbabilisticModel([x])


def model_statistics(dic):
    # Statistic inspected by the decision rule: the raw sampled values of x.
    return dic[x].data


def truncation_rule(a):
    # Accept values whose magnitude lies strictly inside (0.5, 0.6),
    # i.e. a in (0.5, 0.6) or a in (-0.6, -0.5).
    magnitude = abs(a)
    return (magnitude > 0.5) & (magnitude < 0.6)


truncated_model = truncate_model(model=model,
                                 truncation_rule=truncation_rule,
                                 model_statistics=model_statistics)
plot_density(truncated_model, variables=["mu", "x"], number_samples=10000)
plt.show()