def __init__(self):
    self.learnable_model = False  # TODO: to implement later
    self.needs_sampler = False
    self.learnable_sampler = False
    # Squared L2 deviation between two variables, computed on detached tensors.
    self.deviation = lambda x, y: np.sum(
        ((to_tensor(x) - to_tensor(y))**2).detach().numpy())
    # RBF-style kernel applied to a deviation, with a fixed bandwidth.
    self.kernel = lambda d, bw: np.exp(-d / (2 * bw))
    self.bandwidth = 0.01
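
# A minimal, self-contained sketch (not part of the class above) of how the default
# `deviation` and `kernel` interact: the squared L2 distance between two particle
# values is turned into an RBF-style weight using the default bandwidth of 0.01.
# It assumes `to_tensor` behaves like a plain conversion to a torch tensor; the
# arrays below are made up for illustration.
import numpy as np
import torch

x_example = np.array([0.0, 1.0, 2.0])
y_example = np.array([0.1, 0.9, 2.2])
deviation_example = np.sum(((torch.tensor(x_example) - torch.tensor(y_example)) ** 2).numpy())
kernel_weight = np.exp(-deviation_example / (2 * 0.01))  # close particles give weights near 1
print(deviation_example, kernel_weight)  # 0.06, exp(-3) ~= 0.0498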
def __init__(self, variational_samplers, particles, cost_function=None,
             deviation_statistics=None, biased=False, number_post_samples=8000):
    self.learnable_model = False  # TODO: to implement later
    self.needs_sampler = True
    self.learnable_sampler = True
    self.biased = biased
    self.number_post_samples = number_post_samples
    if cost_function:
        self.cost_function = cost_function
    else:
        # Default cost: squared L2 distance summed over all non-batch dimensions.
        self.cost_function = lambda x, y: sum_from_dim(
            (to_tensor(x) - to_tensor(y))**2, dim_index=1)
    if deviation_statistics:
        self.deviation_statistics = deviation_statistics
    else:
        self.deviation_statistics = lambda lst: sum(lst)

    def model_statistics(dic):
        # Deviation of the input samples from each particle, one column per particle.
        num_samples = list(dic.values())[0].shape[0]
        reassigned_particles = [
            reassign_samples(p._get_sample(num_samples),
                             source_model=p, target_model=dic)
            for p in particles
        ]
        statistics = [
            self.deviation_statistics([
                self.cost_function(value_pair[0],
                                   value_pair[1]).detach().numpy()  # TODO: same as above + GPU
                for var, value_pair in zip_dict(dic, p).items()
            ])
            for p in reassigned_particles
        ]
        return np.array(statistics).transpose()

    # Each sampler is truncated to the region where its own particle is the closest one.
    truncation_rules = [
        lambda a, idx=index: True if (idx == np.argmin(a)) else False
        for index in range(len(particles))
    ]
    self.sampler_model = [
        truncate_model(model=sampler,
                       truncation_rule=rule,
                       model_statistics=model_statistics)
        for sampler, rule in zip(variational_samplers, truncation_rules)
    ]
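
# A minimal, self-contained sketch (not part of the class above) of the truncation
# rules built in the constructor: given a vector of per-particle deviation statistics
# (one entry per particle, smaller means closer), exactly one rule fires, namely the
# rule whose index attains the minimum. The statistics vector below is made up.
import numpy as np

example_num_particles = 3
example_truncation_rules = [
    lambda a, idx=index: True if (idx == np.argmin(a)) else False
    for index in range(example_num_particles)
]
example_statistics = np.array([0.8, 0.2, 1.5])  # hypothetical deviations from each particle
print([rule(example_statistics) for rule in example_truncation_rules])  # [False, True, False]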
def get_particle_loss(self, joint_model, particle_list, sampler_model,
                      number_samples, input_values):
    samples_list = [
        sampler._get_sample(number_samples, input_values=input_values, max_itr=1)
        for sampler in sampler_model
    ]
    if self.biased:
        # Biased estimator: uniform weights over the posterior samples.
        importance_weights = [1. / number_samples for _ in sampler_model]
    else:
        importance_weights = [
            joint_model.get_importance_weights(q_samples=samples,
                                               q_model=sampler,
                                               for_gradient=False).flatten()
            for samples, sampler in zip(samples_list, sampler_model)
        ]
    reassigned_samples_list = [
        reassign_samples(samples, source_model=sampler, target_model=particle)
        for samples, sampler, particle in zip(samples_list, sampler_model, particle_list)
    ]
    # Pair each particle value with its posterior samples, variable by variable.
    pair_list = [
        zip_dict(particle._get_sample(1), samples)
        for particle, samples in zip(particle_list, reassigned_samples_list)
    ]
    if not self.biased:
        particle_loss = sum([
            torch.sum(
                to_tensor(w) * self.deviation_statistics([
                    self.cost_function(value_pair[0],
                                       value_pair[1].detach())  # TODO: numpy()
                    for var, value_pair in particle.items()
                ]))
            for particle, w in zip(pair_list, importance_weights)
        ])
    else:
        particle_loss = sum([
            torch.sum(
                self.deviation_statistics([
                    self.cost_function(value_pair[0], value_pair[1].detach())
                    for var, value_pair in particle.items()
                ]))
            for particle in pair_list
        ])
    return particle_loss
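
# A minimal, self-contained sketch (not part of the method above) of how the unbiased
# branch aggregates the loss for a single particle and a single variable: per-sample
# squared costs between the particle value and detached posterior samples are scaled
# by importance weights and summed, so gradients flow only into the particle.
# All tensors below are made up for illustration.
import torch

particle_value = torch.tensor([[0.5, 0.5]], requires_grad=True)          # one particle value
posterior_samples = torch.tensor([[0.4, 0.6], [0.9, 0.1], [0.5, 0.45]])  # three posterior samples
weights = torch.tensor([0.2, 0.3, 0.5])                                  # importance weights

costs = torch.sum((particle_value - posterior_samples.detach()) ** 2, dim=1)  # default cost per sample
example_particle_loss = torch.sum(weights * costs)
example_particle_loss.backward()
print(example_particle_loss.item(), particle_value.grad)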