示例#1
0
        def model_statistics(dic):
            """Compute deviation statistics of ``dic`` against every particle.

            The number of samples drawn from each particle is taken from the
            leading dimension of the first value stored in ``dic``. Each
            particle's samples are re-keyed onto ``dic`` with
            ``reassign_samples`` and paired with the entries of ``dic`` through
            ``zip_dict``; ``self.cost_function`` is evaluated on every pair
            (converted to NumPy arrays) and the resulting list is reduced by
            ``self.deviation_statistics``.

            Returns:
                The per-particle statistics stacked into a NumPy array and
                transposed.
            """
            sample_count = list(dic.values())[0].shape[0]

            # Draw and re-key the samples of all particles before computing any
            # cost, so sampling happens in the same order as the particles.
            remapped_samples = []
            for particle in particles:
                drawn = particle._get_sample(sample_count)
                remapped_samples.append(
                    reassign_samples(drawn, source_model=particle, target_model=dic))

            per_particle_stats = []
            for remapped in remapped_samples:
                costs = []
                for _, pair in zip_dict(dic, remapped).items():
                    # pair[0] comes from ``dic``, pair[1] from the particle.
                    reference = pair[0].detach().cpu().numpy()
                    candidate = pair[1].detach().cpu().numpy()
                    costs.append(self.cost_function(reference, candidate))
                per_particle_stats.append(self.deviation_statistics(costs))

            return np.array(per_particle_stats).transpose()
示例#2
0
 def get_particle_loss(self, joint_model, particle_list, sampler_model,
                       number_samples, input_values):
     """Return the total deviation between particles and sampler draws.

     For every (sampler, particle) pair, ``number_samples`` samples are
     drawn from the sampler, re-keyed onto the particle with
     ``reassign_samples`` and paired (via ``zip_dict``) with a single
     sample drawn from the particle. ``self.cost_function`` is evaluated on
     each pair, with the sampler side detached from the autograd graph, and
     the costs are reduced by ``self.deviation_statistics``.

     When ``self.biased`` is true the per-particle statistics are summed
     without weighting. Otherwise each statistic is multiplied by the
     importance weights returned by ``joint_model.get_importance_weights``
     before summation.

     Args:
         joint_model: Object providing ``get_importance_weights``; only used
             when ``self.biased`` is false.
         particle_list: Iterable of particle models, aligned with
             ``sampler_model``.
         sampler_model: Iterable of sampler models (iterated more than
             once, so it must be re-iterable).
         number_samples: Number of samples requested from each sampler.
         input_values: Forwarded to each sampler's ``_get_sample``.

     Returns:
         The sum over particles of the (optionally weighted) summed
         deviation statistics; ``0`` if there are no particles.
     """
     biased = self.biased

     samples_list = [
         sampler._get_sample(number_samples,
                             input_values=input_values,
                             max_itr=1)
         for sampler in sampler_model
     ]

     # Importance weights are only needed for the unbiased estimate; the
     # biased loss below is an unweighted sum, so nothing is computed here
     # in that case.
     importance_weights = None
     if not biased:
         importance_weights = [
             joint_model.get_importance_weights(
                 q_samples=samples, q_model=sampler,
                 for_gradient=False).flatten()
             for samples, sampler in zip(samples_list, sampler_model)
         ]

     reassigned_samples_list = [
         reassign_samples(samples,
                          source_model=sampler,
                          target_model=particle)
         for samples, sampler, particle in zip(samples_list, sampler_model,
                                               particle_list)
     ]

     pair_list = [
         zip_dict(particle._get_sample(1), samples)
         for particle, samples in zip(particle_list,
                                      reassigned_samples_list)
     ]

     def deviation(pairs):
         # Cost of the particle sample against the detached sampler draw,
         # reduced over all paired variables.
         return self.deviation_statistics([
             self.cost_function(value_pair[0],
                                value_pair[1].detach())  # TODO: numpy()
             for _, value_pair in pairs.items()
         ])

     if biased:
         return sum(torch.sum(deviation(pairs)) for pairs in pair_list)

     return sum(
         torch.sum(to_tensor(w) * deviation(pairs))
         for pairs, w in zip(pair_list, importance_weights))