def generate_step(energy, integrator: Langevin = None, ctx=None):
    """Draw and log a batch of images via annealed, conditioned Langevin runs.

    The chain starts from Gaussian noise scaled by 5 with shape
    ``(ctx.batch_size, 3, 32, 32)``. One ``integrator.integrate`` pass is made
    per noise level, walking the schedule ``0.99, 0.98, ..., 0.0`` from high to
    low. At every level the energy is wrapped in ``ConditionalEnergy``,
    anchored on the chain's state at that moment, with ``shift=0.025``. The
    final state is mapped from [-1, 1] to [0, 1], clamped, and logged under
    the ``samples`` key.

    Args:
        energy: energy model handed to ``ConditionalEnergy`` at every level.
        integrator: Langevin integrator. It is required even though it
            defaults to ``None``; presumably the caller or framework injects
            it (NOTE(review): confirm).
        ctx: context object; this function reads ``batch_size`` and
            ``device`` and calls ``log``.
    """
    device = ctx.device
    chain = torch.randn(ctx.batch_size, 3, 32, 32, device=device) * 5
    schedule = torch.arange(0.0, 1.0, 0.01, device=device)

    # Walk the schedule from the highest level down to 0.0.
    for noise_level in schedule.flip(0):
        # Broadcast the scalar level to one entry per batch element.
        per_sample_level = torch.ones(chain.size(0), device=chain.device) * noise_level
        # Condition on the chain state *before* this level's integration.
        conditioned = ConditionalEnergy(energy, chain, shift=0.025)
        chain = integrator.integrate(conditioned, chain, per_sample_level, None)

    # Rescale from [-1, 1] into displayable [0, 1] and clip any overshoot.
    images = chain.add(1).div(2).clamp(0, 1)
    ctx.log(samples=LogImage(images))
def generate_step(energy, base, integrator: Langevin = None, ctx=None):
    """Draw and log a batch of images with a single Langevin integration.

    A starting batch is drawn with ``base.sample(ctx.batch_size)``. It is
    integrated once under ``energy`` with every per-sample level set to zero.
    The result is clamped to [0, 1] (no rescaling) and logged under the
    ``samples`` key.

    Args:
        energy: energy model passed straight to ``integrator.integrate``.
        base: object whose ``sample(n)`` supplies the initial batch.
        integrator: Langevin integrator. It is required even though it
            defaults to ``None``; presumably injected by the caller or
            framework (NOTE(review): confirm).
        ctx: context object; this function reads ``batch_size`` and calls
            ``log``.
    """
    start = base.sample(ctx.batch_size)
    # A zero level for every sample in the batch, on the same device as the batch.
    zero_levels = torch.zeros(ctx.batch_size, device=start.device)
    final = integrator.integrate(energy, start, zero_levels, None)
    ctx.log(samples=LogImage(final.clamp(0, 1)))