Пример #1
0
 def __init__(self):
     """Initialize the capability flags and the deviation/kernel functions.

     Sets ``learnable_model``, ``needs_sampler`` and ``learnable_sampler`` to
     False, a ``deviation`` function returning the sum of squared differences
     of two tensor-convertible values (as a NumPy scalar), an exponential
     ``kernel`` of a deviation and a bandwidth, and ``bandwidth = 0.01``.
     """
     self.learnable_model = False  # TODO: to implement later
     self.needs_sampler = False
     self.learnable_sampler = False
     # Sum of squared element-wise differences. ``.cpu()`` must precede
     # ``.numpy()``: NumPy conversion raises for tensors on a CUDA device,
     # while ``.cpu()`` is a no-op for tensors already on the CPU.
     self.deviation = lambda x, y: np.sum(
         ((to_tensor(x) - to_tensor(y))**2).detach().cpu().numpy())
     # exp(-d / (2 * bw)): bw scales the deviation d directly (not squared).
     self.kernel = lambda d, bw: np.exp(-d / (2 * bw))
     self.bandwidth = 0.01
Пример #2
0
    def __init__(self,
                 variational_samplers,
                 particles,
                 cost_function=None,
                 deviation_statistics=None,
                 biased=False,
                 number_post_samples=8000):
        """Set up the cost functions and one truncated sampler per particle.

        Args:
            variational_samplers: sampler models; the i-th one is wrapped by
                ``truncate_model`` with the i-th truncation rule. Pairing with
                the rules uses ``zip``, so the shorter of
                ``variational_samplers`` and ``particles`` determines how many
                truncated samplers are built.
            particles: particle models; each is sampled inside
                ``model_statistics`` and one truncation rule is built per
                particle.
            cost_function: callable ``(x, y)``. When ``None`` it defaults to
                ``sum_from_dim((to_tensor(x) - to_tensor(y))**2, dim_index=1)``.
            deviation_statistics: callable reducing a list of per-variable
                costs to one value. When ``None`` it defaults to ``sum``.
            biased: stored as ``self.biased``.
            number_post_samples: stored as ``self.number_post_samples``.
        """
        self.learnable_model = False  # TODO: to implement later
        self.needs_sampler = True
        self.learnable_sampler = True
        self.biased = biased
        self.number_post_samples = number_post_samples
        # Compare against None instead of relying on truthiness, so any
        # supplied callable is honored even if its __bool__/__len__ is falsy.
        if cost_function is not None:
            self.cost_function = cost_function
        else:
            self.cost_function = lambda x, y: sum_from_dim(
                (to_tensor(x) - to_tensor(y))**2, dim_index=1)
        if deviation_statistics is not None:
            self.deviation_statistics = deviation_statistics
        else:
            self.deviation_statistics = sum

        def model_statistics(dic):
            """Return deviations of every particle against the draw ``dic``.

            The outer list is indexed by particle; ``transpose()`` moves that
            axis last. The exact shape depends on what ``cost_function``
            returns (presumably one value per sample -- TODO confirm).
            """
            # Sample count is read from the leading dimension of the first
            # value in ``dic``.
            num_samples = list(dic.values())[0].shape[0]
            reassigned_particles = [
                reassign_samples(p._get_sample(num_samples),
                                 source_model=p,
                                 target_model=dic) for p in particles
            ]

            statistics = [
                self.deviation_statistics([
                    # .cpu() before .numpy(): NumPy conversion raises for CUDA
                    # tensors, and .cpu() is a no-op on CPU tensors.
                    self.cost_function(value_pair[0],
                                       value_pair[1]).detach().cpu().numpy()
                    for _, value_pair in zip_dict(dic, p).items()
                ]) for p in reassigned_particles
            ]
            return np.array(statistics).transpose()

        # Rule i returns True exactly when position i holds the minimum of
        # ``a`` (first minimum on ties, per np.argmin). ``idx=index`` binds
        # the current loop value, avoiding the late-binding closure pitfall.
        truncation_rules = [
            lambda a, idx=index: bool(idx == np.argmin(a))
            for index in range(len(particles))
        ]

        self.sampler_model = [
            truncate_model(model=sampler,
                           truncation_rule=rule,
                           model_statistics=model_statistics)
            for sampler, rule in zip(variational_samplers, truncation_rules)
        ]
Пример #3
0
 def get_particle_loss(self, joint_model, particle_list, sampler_model,
                       number_samples, input_values):
     """Return the particle loss summed over (particle, sampler) pairs.

     Each sampler in ``sampler_model`` draws ``number_samples`` samples.
     These are reassigned to the matching particle in ``particle_list`` and
     paired, via ``zip_dict``, with one sample of that particle. For every
     pair dict, ``self.deviation_statistics`` reduces the per-variable
     ``self.cost_function`` values, and ``torch.sum`` collapses the result.

     Args:
         joint_model: provides ``get_importance_weights``; only consulted
             when ``self.biased`` is false.
         particle_list: particle models, matched by position with
             ``sampler_model`` (``zip`` stops at the shortest sequence).
         sampler_model: sampler models to draw from.
         number_samples: number of samples drawn from each sampler.
         input_values: forwarded to each sampler's ``_get_sample``.

     Returns:
         The sum of the per-pair terms. Each term is weighted by the
         flattened importance weights when ``self.biased`` is false.
         Returns ``0`` when there are no pairs.
     """
     samples_list = [
         sampler._get_sample(number_samples,
                             input_values=input_values,
                             max_itr=1) for sampler in sampler_model
     ]
     # Importance weights are only used by the unbiased loss. The biased loss
     # below is unweighted, so nothing is computed for it.
     if self.biased:
         importance_weights = None
     else:
         importance_weights = [
             joint_model.get_importance_weights(
                 q_samples=samples, q_model=sampler,
                 for_gradient=False).flatten()
             for samples, sampler in zip(samples_list, sampler_model)
         ]
     reassigned_samples_list = [
         reassign_samples(samples, source_model=sampler, target_model=particle)
         for samples, sampler, particle in zip(samples_list, sampler_model,
                                               particle_list)
     ]
     pair_list = [
         zip_dict(particle._get_sample(1), samples)
         for particle, samples in zip(particle_list, reassigned_samples_list)
     ]

     def pair_deviation(pairs):
         # Reduce the per-variable costs of one pair dict. value_pair[1] is
         # detached, so gradients only flow through value_pair[0].
         return self.deviation_statistics([
             self.cost_function(value_pair[0],
                                value_pair[1].detach())  # TODO: numpy()
             for _, value_pair in pairs.items()
         ])

     if importance_weights is None:
         # NOTE(review): no 1/number_samples factor is applied in the biased
         # case; confirm this scaling is intended.
         return sum(torch.sum(pair_deviation(pairs)) for pairs in pair_list)
     return sum(
         torch.sum(to_tensor(w) * pair_deviation(pairs))
         for pairs, w in zip(pair_list, importance_weights))