def run_sliced_score(score, data, args, noise_level=0.0):
    if noise_level:
        data = gaussian_noise(data, noise_level)
    # The sliced objective needs gradients with respect to the input.
    make_differentiable(data)
    score_val = score(data, noise_level, *args)
    loss = sliced_score(score_val, data)
    return loss, namespace(data=data, score=score_val, noise_level=noise_level)
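# Note (sketch of the objective, assuming the usual sliced score matching
# formulation): for random projections v, sliced_score estimates
#   E_v[ v^T (d s(x)/dx) v + 0.5 * (v^T s(x))^2 ],
# which avoids the full Jacobian trace of exact score matching.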
def forward(self, data, condition):
    # self.logv stores the log-variance, so the scale is exp(logv / 2).
    dist = Normal(self.mean, (0.5 * self.logv).exp())
    log_p = dist.log_prob(data)
    # Flatten the event dimensions and sum to a per-sample log-likelihood.
    log_p = log_p.view(*log_p.shape[:-3], -1)
    return log_p.sum(dim=-1, keepdim=True), namespace(distribution=dist)
def run_denoising_score(score, data, args, noise_level):
    noised = gaussian_noise(data, noise_level)
    score_val = score(noised, noise_level, *args)
    loss = denoising_score(score_val, data, noised, noise_level)
    return loss, namespace(data=data, noised=noised, score=score_val,
                           noise_level=noise_level)
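# Note: under Gaussian corruption the denoising target is the score of
# q(noised | data), i.e. -(noised - data) / noise_level**2 (Vincent, 2011);
# denoising_score is assumed to regress score_val onto this target.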
def run_density_ratio(energy, sample, data, args):
    real_data, fake_data = data, sample
    real, fake = energy(real_data, args), energy(fake_data, args)
    loss = density_ratio_estimation(real, fake)
    return loss, namespace(real_data=real_data, fake_data=fake_data,
                           real=real, fake=fake)
def run_generator(generator, discriminator, gan_loss=non_saturating,
                  gan_loss_kwargs=None, ctx=None):
    fake_data = generator.sample(ctx.batch_size)
    fake = discriminator(fake_data)
    # Guard against gan_loss_kwargs=None, mirroring run_discriminator.
    loss = gan_loss(ctx=ctx, **(gan_loss_kwargs or {})).generator(fake)
    return loss, namespace(fake_data=fake_data, fake=fake, ctx=ctx)
def run_discriminator(discriminator, real_data, fake_data,
                      gan_loss=non_saturating, gan_loss_kwargs=None, ctx=None):
    real, fake = discriminator(real_data), discriminator(fake_data)
    loss = gan_loss(ctx=ctx, **(gan_loss_kwargs or {})).critic(real, fake)
    return loss, namespace(real_data=real_data, fake_data=fake_data,
                           real=real, fake=fake, ctx=ctx)
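# Usage sketch (illustrative only; assumes `generator`, `discriminator`,
# `real_batch`, and a `ctx` exposing `batch_size` exist in the caller):
#
#   g_loss, g_out = run_generator(generator, discriminator, ctx=ctx)
#   d_loss, d_out = run_discriminator(
#       discriminator, real_batch, g_out.fake_data.detach(), ctx=ctx)
#
# Detaching fake_data keeps the discriminator step from backpropagating
# into the generator.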
def run_tdre(energy, base, data, args, mixing=None):
    real_data, fake_data, levels, _ = mixing(data, base)
    real = energy(real_data, levels, args)
    fake = energy(fake_data, levels, args)
    # One density-ratio loss per noise level, averaged into a scalar.
    level_losses = density_ratio_estimation(real, fake)
    return level_losses.mean(), namespace(real_data=real_data, fake_data=fake_data,
                                          real=real, fake=fake, levels=levels,
                                          level_losses=level_losses)
def finite_difference_score_serial(score, data, args, noise_level=0.0, eps=1e-3):
    if noise_level:
        data = gaussian_noise(data, noise_level)
    # Two separate score evaluations at the +eps and -eps probe points.
    data_p, data_m, v = finite_difference_input(data, eps=eps, parallel=False)
    s_p = score(data_p, noise_level, *args)
    s_m = score(data_m, noise_level, *args)
    loss = finite_difference_score(s_p, s_m, v, eps=eps)
    return loss, namespace(s_p=s_p, s_m=s_m, v=v)
def finite_difference_score_parallel(score, data, args, noise_level=0.0, eps=1e-3):
    if noise_level:
        data = gaussian_noise(data, noise_level)
    # Pack both probe points into one batch and evaluate the score once.
    data, v = finite_difference_input(data, eps=eps)
    score_val = score(data, _replicate_aux(noise_level), *_replicate_aux(args))
    s_p, s_m = _split_aux(score_val)
    loss = finite_difference_score(s_p, s_m, v, eps=eps)
    return loss, namespace(s_p=s_p, s_m=s_m, v=v)
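# Both finite-difference variants approximate the directional derivative of
# the score along the probe v:
#   v . (s(x + eps*v) - s(x - eps*v)) / (2 * eps)  ~=  v^T (ds/dx) v,
# with the parallel variant batching the +eps/-eps probes through a single
# score call via the (assumed) _replicate_aux/_split_aux packing helpers.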
def filter_kwargs(kwargs, **targets):
    result = {}
    for name, target in targets.items():
        target_kwargs, has_kwargs = get_kwargs(target)
        if has_kwargs:
            # The target accepts **kwargs: forward everything.
            result[name] = kwargs
        else:
            # Otherwise keep only the keywords the target's signature accepts.
            result[name] = {key: kwargs[key] for key in target_kwargs
                            if key in kwargs}
    return namespace(**result)
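# Example (hypothetical callables; get_kwargs(fn) is assumed to return the
# pair (accepted_keyword_names, accepts_var_kwargs)):
#
#   def f(lr): ...
#   def g(beta, **rest): ...
#   routed = filter_kwargs({"lr": 1e-3, "beta": 0.9}, f=f, g=g)
#   # routed.f == {"lr": 1e-3}; routed.g == {"lr": 1e-3, "beta": 0.9}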
def run_diffusion_recovery_likelihood(energy, base, data, args,
                                      integrator=None, mixing=None,
                                      conditional=None):
    real_data, condition, levels, _ = mixing(data, base)
    # Draw negatives from the conditional energy via the sampler/integrator.
    conditional_energy = conditional(energy, condition)
    fake_data = integrator.integrate(conditional_energy, condition, args)
    real, fake = energy(real_data, levels, args), energy(fake_data, levels, args)
    loss = real.mean() - fake.mean()
    return loss, namespace(real_data=real_data, fake_data=fake_data,
                           condition=condition, real=real, fake=fake)
def run_tnce(energy, base, data, args, mixing=None,
             noise_contrastive=probability_surface_estimation):
    real_data, fake_data, real_levels, fake_levels = mixing(data, base)
    real = energy(real_data, real_levels, args)
    fake = energy(fake_data, real_levels, args)
    real_base = energy(real_data, fake_levels, args)
    fake_base = energy(fake_data, fake_levels, args)
    # At the final level the exact base log-density replaces the learned energy.
    is_base = (fake_levels == 1.0)[:, None]
    base_real = base.log_prob(real_data, args)
    base_fake = base.log_prob(fake_data, args)
    real_base = (~is_base).float() * real_base + is_base.float() * base_real
    fake_base = (~is_base).float() * fake_base + is_base.float() * base_fake
    level_losses = noise_contrastive(real, fake, real_base, fake_base)
    return level_losses.mean(), namespace(real_data=real_data, fake_data=fake_data,
                                          real=real, fake=fake,
                                          level_losses=level_losses,
                                          real_levels=real_levels,
                                          fake_levels=fake_levels)
def maximum_likelihood_step(model, data, ctx=None):
    data, condition = data.sample(ctx.batch_size)
    log_p, args = model(data, condition)
    # Maximize the log-likelihood through the context's optimizer.
    ctx.argmax(log_likelihood=log_p)
    return namespace(data=data, condition=condition, **args.asdict())
def forward(self, data, condition):
    distribution = self.predictor(condition)
    return distribution.log_prob(data), namespace(distribution=distribution)