def model_statistics(dic):
    # Evaluate the deviation statistics between each particle and the samples in `dic`.
    # Note: `self` and `particles` are free variables, expected from the enclosing scope.
    num_samples = list(dic.values())[0].shape[0]
    reassigned_particles = [reassign_samples(p._get_sample(num_samples),
                                             source_model=p,
                                             target_model=dic)
                            for p in particles]
    statistics = [self.deviation_statistics([self.cost_function(value_pair[0].detach().cpu().numpy(),
                                                                value_pair[1].detach().cpu().numpy())
                                             for var, value_pair in zip_dict(dic, p).items()])
                  for p in reassigned_particles]
    return np.array(statistics).transpose()
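# A minimal standalone sketch of the per-particle statistic evaluated above, using plain
# numpy with a hypothetical squared-Euclidean cost and a plain sum as the deviation
# statistic; these stand-ins are assumptions for illustration, not necessarily the
# defaults used by this class.
def _example_model_statistic(samples_a, samples_b):
    import numpy as np

    def example_cost(x, y):
        # Squared Euclidean distance per sample, summed over the event dimensions.
        return ((np.asarray(x) - np.asarray(y)) ** 2).reshape(len(x), -1).sum(axis=1)

    # One cost array per shared variable, aggregated into a single per-sample statistic,
    # e.g. _example_model_statistic({"x": a}, {"x": b}) returns an array of length num_samples.
    return sum(example_cost(samples_a[var], samples_b[var]) for var in samples_a)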
def get_particle_loss(self, joint_model, particle_list, sampler_model, number_samples, input_values):
    # Draw samples from each sampler distribution (one sampler per particle).
    samples_list = [sampler._get_sample(number_samples, input_values=input_values, max_itr=1)
                    for sampler in sampler_model]

    # Uniform weights for the biased estimator; otherwise importance weights from the joint model.
    if self.biased:
        importance_weights = [1. / number_samples for _ in sampler_model]
    else:
        importance_weights = [joint_model.get_importance_weights(q_samples=samples,
                                                                 q_model=sampler,
                                                                 for_gradient=False).flatten()
                              for samples, sampler in zip(samples_list, sampler_model)]

    # Map each sampler's variables onto the corresponding particle's variables.
    reassigned_samples_list = [reassign_samples(samples, source_model=sampler, target_model=particle)
                               for samples, sampler, particle in zip(samples_list,
                                                                     sampler_model,
                                                                     particle_list)]

    # Pair each particle's location with the reassigned samples, variable by variable.
    pair_list = [zip_dict(particle._get_sample(1), samples)
                 for particle, samples in zip(particle_list, reassigned_samples_list)]

    if not self.biased:
        # Importance-weighted sum of the deviation statistics of the pairwise costs.
        particle_loss = sum([torch.sum(to_tensor(w) *
                                       self.deviation_statistics([self.cost_function(value_pair[0],
                                                                                     value_pair[1].detach())  # TODO: numpy()
                                                                  for var, value_pair in particle.items()]))
                             for particle, w in zip(pair_list, importance_weights)])
    else:
        # Unweighted sum (the uniform weights computed above are not applied in this branch).
        particle_loss = sum([torch.sum(self.deviation_statistics([self.cost_function(value_pair[0],
                                                                                      value_pair[1].detach())
                                                                   for var, value_pair in particle.items()]))
                             for particle in pair_list])
    return particle_loss
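# A minimal sketch of the aggregation pattern used above, with plain tensors in place of
# the library's models; the function name and arguments are illustrative only. In the
# unbiased case each particle's per-sample deviations are scaled by its importance weights
# before summing; in the biased case the deviations are summed directly.
def _example_particle_loss(deviations, importance_weights, biased=False):
    # deviations: list of tensors, one per particle, each of shape (number_samples,)
    # importance_weights: list of arrays or tensors with matching shapes
    import torch
    if biased:
        return sum(torch.sum(d) for d in deviations)
    return sum(torch.sum(torch.as_tensor(w) * d)
               for w, d in zip(importance_weights, deviations))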