def compute_kale(self, data_loader, base_loader, precomputed_stats=None):
    """Estimate the KALE divergence between the data and base distributions.

    Runs the discriminator (as an energy function, energy = -D(x)) over both
    loaders and combines the two averages with a log-partition estimate.

    Args:
        data_loader: iterable yielding (data, target) batches from the data
            distribution.
        base_loader: iterable yielding image batches from the base/reference
            distribution.
        precomputed_stats: optional (base_mean, log_partition) pair; when
            given, the base-loader pass is skipped entirely.

    Returns:
        (KALE, base_mean, log_partition) — KALE = data_mean + base_mean + 1.
    """
    self.discriminator.eval()
    base_mean = torch.tensor(0.).to(self.device)
    data_mean = 0
    if precomputed_stats is None:
        # M counts samples seen so far; cp.iterative_* are streaming
        # accumulators (project helpers) — presumably running mean /
        # running log-sum-exp; verify against the cp module.
        M = 0
        with torch.no_grad():
            for img in base_loader:
                # Energy convention: E(x) = -D(x).
                energy = -self.discriminator(img.to(self.device))
                if self.args.criterion == 'donsker':
                    # Donsker–Varadhan: base_mean here accumulates a
                    # log-sum-exp, not a mean (renamed below).
                    base_mean, M = cp.iterative_log_sum_exp(
                        torch.exp(energy), base_mean, M)
                else:
                    # Plain KALE: average of -exp(E - log Z) over the base.
                    energy = -torch.exp(energy - self.log_partition)
                    base_mean, M = cp.iterative_mean(energy, base_mean, M)
        if self.args.criterion == 'donsker':
            # log Z estimated as log-mean-exp = log-sum-exp - log(M);
            # the base term then collapses to the constant -1.
            log_partition = 1. * base_mean - np.log(M)
            base_mean = torch.tensor(-1.).to(self.device)
        else:
            log_partition = self.log_partition
    else:
        base_mean, log_partition = precomputed_stats
    # Second pass: streaming mean of -(D(x) + log Z) over the data loader.
    M = 0
    for data, target in data_loader:
        with torch.no_grad():
            data_energy = -(self.discriminator(data.to(self.device))
                            + log_partition)
            data_mean, M = cp.iterative_mean(data_energy, data_mean, M)
    KALE = data_mean + base_mean + 1
    return KALE, base_mean, log_partition
def log_partition(self, N):
    """Importance-sampling estimate of the log partition function.

    Draws N points from the sampler, forms log-weights
    potential(x) - ||x||^2 / 2, and reduces them with a streaming
    log-sum-exp; a Gaussian normalization constant is added at the end.

    Args:
        N: number of samples to draw.

    Returns:
        Scalar tensor holding the estimated log partition value.
    """
    samples = self.sample(None, N)
    # Log importance weights w.r.t. a standard Gaussian proposal.
    log_weights = self.sampler.potential(samples) \
        - 0.5 * torch.norm(samples, dim=1) ** 2
    lse = torch.tensor(0.).to(self.device)
    count = 0
    lse, count = cp.iterative_log_sum_exp(log_weights, lse, count)
    # -log-mean-exp of the weights plus the Gaussian constant d/2 * log(2*pi).
    dim = samples.shape[1]
    return -lse + np.log(count) + 0.5 * dim * np.log(2. * np.pi)
def init_log_partition(self):
    """Initialize the log partition estimate from generator samples.

    Streams 100 generator batches through the discriminator and computes
    log-mean-exp of the negated discriminator outputs.

    Returns:
        A fresh scalar tensor (detached from the graph) on self.device.
    """
    NUM_BATCHES = 100
    self.generator.eval()
    self.discriminator.eval()
    running_lse = torch.tensor(0.).to(self.device)
    count = 0
    with torch.no_grad():
        for _ in range(NUM_BATCHES):
            noise = self.noise_gen.sample([self.args.sample_b_size])
            # Energy convention E(x) = -D(x) on generated samples.
            neg_energy = -self.discriminator(self.generator(noise))
            running_lse, count = cp.iterative_log_sum_exp(
                neg_energy, running_lse, count)
    # log-mean-exp = log-sum-exp - log(total sample count).
    estimate = running_lse - np.log(count)
    # .item() round-trip detaches the value from any autograd history.
    return torch.tensor(estimate.item()).to(self.device)