Beispiel #1
0
 def __init__(self,
              score,
              *args,
              buffer_size=10000,
              buffer_probability=0.95,
              sample_steps=10,
              decay=1,
              reset_threshold=1000,
              integrator=None,
              oos_penalty=True,
              accept_probability=1.0,
              sampler_likelihood=1.0,
              maximum_entropy=0.3,
              **kwargs):
     """Set up an energy-based training run around the network ``score``.

     ``score`` is handed to the base-class ``__init__`` as ``{"score": score}``;
     all other positional and keyword arguments are forwarded to it unchanged.

     Args:
       score: the network being trained. A deep copy of it, put into eval
         mode, is kept as ``self.target_score``.
       buffer_size, buffer_probability, accept_probability: forwarded as
         keyword arguments to ``SampleBuffer``.
       sample_steps, decay, reset_threshold, oos_penalty, sampler_likelihood,
       maximum_entropy: stored on ``self`` under the same names; how they are
         used is defined by methods outside this block.
       integrator: sampler stored as ``self.integrator``; when ``None`` a
         default-constructed ``Langevin()`` is created for this instance.
     """
     # NOTE(review): Ellipsis placeholder assigned before the base __init__
     # runs; presumably the base class replaces it with the network passed as
     # {"score": score} — confirm against the base class.
     self.score = ...
     super(EnergyTraining, self).__init__({"score": score}, *args, **kwargs)
     self.sampler_likelihood = sampler_likelihood
     self.maximum_entropy = maximum_entropy
     # Copied after the base __init__, so the copy reflects whatever that
     # __init__ did to ``score`` (e.g. a device move, if any — unverified).
     self.target_score = deepcopy(score).eval()
     self.reset_threshold = reset_threshold
     self.oos_penalty = oos_penalty
     self.decay = decay
     self.integrator = integrator if integrator is not None else Langevin()
     self.sample_steps = sample_steps
     # The buffer receives the training object itself; it may read the
     # attributes assigned above (unverified), so keep this after them.
     self.buffer = SampleBuffer(self,
                                buffer_size=buffer_size,
                                buffer_probability=buffer_probability,
                                accept_probability=accept_probability)
     # A loader factory, not a loader: self.batch_size is looked up each time
     # it is called (presumably set by the base __init__ — confirm).
     self.buffer_loader = lambda x: DataLoader(
         x, batch_size=self.batch_size, shuffle=True, drop_last=True)
Beispiel #2
0
def generate_step(energy, integrator: Langevin = None, ctx=None, *,
                  shape=(3, 32, 32), init_scale=5.0, level_step=0.01,
                  shift=0.025):
    """Generate a batch of samples level by level and log them as images.

    The sample starts as Gaussian noise scaled by ``init_scale``. It is then
    integrated once per level, iterating over
    ``torch.arange(0.0, 1.0, level_step)`` in reverse (0.99, 0.98, ..., 0.0
    with the defaults). At each level the energy is wrapped in
    ``ConditionalEnergy(energy, sample, shift=shift)``, conditioned on the
    current sample.

    Args:
      energy: energy model passed to ``ConditionalEnergy``.
      integrator: object whose ``integrate`` method is called as
        ``integrate(conditional_energy, sample, levels, None)``; required.
      ctx: training context providing ``batch_size``, ``device`` and
        ``log``; required.
      shape: per-sample shape (default ``(3, 32, 32)``).
      init_scale: multiplier applied to the initial Gaussian noise.
      level_step: spacing between consecutive levels in [0, 1).
      shift: forwarded to ``ConditionalEnergy``.

    Returns:
      The logged images: the final sample mapped by ``(x + 1) / 2`` and
      clamped to [0, 1].

    Raises:
      ValueError: if ``integrator`` or ``ctx`` is None (both default to None
        but are required).
    """
    if integrator is None:
        raise ValueError("generate_step requires an integrator")
    if ctx is None:
        raise ValueError("generate_step requires a ctx")
    sample = init_scale * torch.randn(
        ctx.batch_size, *shape, device=ctx.device)
    levels = torch.arange(0.0, 1.0, level_step, device=ctx.device)
    for level in reversed(levels):
        # Broadcast the scalar level to one value per sample in the batch.
        this_level = level * torch.ones(sample.size(0), device=sample.device)
        sample = integrator.integrate(
            ConditionalEnergy(energy, sample, shift=shift), sample,
            this_level, None)
    # Map from [-1, 1] to [0, 1] for image logging.
    result = ((sample + 1) / 2).clamp(0, 1)
    ctx.log(samples=LogImage(result))
    return result
Beispiel #3
0
 def __init__(self,
              score,
              *args,
              buffer_size=100,
              buffer_probability=0.9,
              sample_steps=10,
              decay=1,
              integrator=None,
              oos_penalty=True,
              **kwargs):
     """Set up an energy-based training run around the network ``score``.

     ``score`` is handed to the base-class ``__init__`` as ``{"score": score}``;
     all other positional and keyword arguments are forwarded to it unchanged.

     Args:
       score: the network being trained.
       buffer_size, buffer_probability: forwarded as keyword arguments to
         ``SampleBuffer``.
       sample_steps, decay, oos_penalty: stored on ``self`` under the same
         names; how they are used is defined by methods outside this block.
       integrator: sampler stored as ``self.integrator``; when ``None`` a
         default-constructed ``Langevin()`` is created for this instance.
     """
     # NOTE(review): Ellipsis placeholder assigned before the base __init__
     # runs; presumably the base class replaces it with the network passed as
     # {"score": score} — confirm against the base class.
     self.score = ...
     super(EnergyTraining, self).__init__({"score": score}, *args, **kwargs)
     self.oos_penalty = oos_penalty
     self.decay = decay
     self.integrator = integrator if integrator is not None else Langevin()
     self.sample_steps = sample_steps
     # The buffer receives the training object itself; it may read the
     # attributes assigned above (unverified), so keep this after them.
     self.buffer = SampleBuffer(self,
                                buffer_size=buffer_size,
                                buffer_probability=buffer_probability)
     # A loader factory, not a loader: self.batch_size is looked up each time
     # it is called (presumably set by the base __init__ — confirm).
     self.buffer_loader = lambda x: DataLoader(
         x, batch_size=self.batch_size, shuffle=True, drop_last=True)
Beispiel #4
0
    energy = UNetEnergy()
    critic = ConvCritic()

    training = MNISTEnergyTraining(
        energy,
        critic,
        data,
        network_name="runs/experimental/learned-stein/conv-1",
        device="cuda:0",
        batch_size=64,
        decay=10.0,
        report_interval=1000,
        max_epochs=1000,
        optimizer_kwargs={
            "lr": 1e-4,
            "betas": (0.0, 0.9)
        },
        critic_optimizer_kwargs={
            "lr": 1e-4,
            "betas": (0.0, 0.9)
        },
        n_critic=5,
        integrator=Langevin(rate=-0.01,
                            steps=100,
                            noise=0.01,
                            clamp=(-14, 14),
                            max_norm=None),
        verbose=True)

    training.train()
        samples = torch.cat(samples, dim=-1)
        self.writer.add_image("samples", samples, self.step_id)


if __name__ == "__main__":
    import torch.multiprocessing as mp
    mp.set_start_method("spawn")

    mnist = MNIST("examples/", download=False, transform=ToTensor())
    data = EnergyDataset(mnist)

    score = Convolutional(depth=4)
    energy = MNISTEnergy(SharedModule(score, dynamic=False), keep_rate=0.95)
    integrator = Langevin(rate=10,
                          steps=10,
                          noise=0.01,
                          max_norm=None,
                          clamp=(0, 1))

    training = MNISTEnergyTraining(
        score,
        energy,
        data,
        network_name="off-energy/mnist-conv-off-e-loss-8",
        device="cuda:0",
        integrator=integrator,
        off_energy_weight=0,
        batch_size=64,
        optimizer_kwargs=dict(lr=1e-4, betas=(0.0, 0.999)),
        decay=1.0,
        n_workers=8,
Beispiel #6
0
    def prepare(self):
        """Build one random (data, condition) example.

        Returns:
          ``data``: a (1, 28, 28) tensor of uniform noise in [0, 1).
          ``condition``: a length-10 float one-hot vector whose hot index is
          a label drawn uniformly from 0..9.
        """
        one_hot = torch.zeros(10)
        # Draw the label first so the RNG is consumed in the same order as
        # before (label, then noise image).
        one_hot[torch.randint(0, 10, (1, )).item()] = 1
        return torch.rand(1, 28, 28), one_hot

    def each_generate(self, data, args):
        """Log up to the first ten generated samples as one image strip.

        Each sample is clamped to [0, 1]; the clamped samples are joined
        along their last dimension and written to ``self.writer`` under the
        tag ``"samples"`` at step ``self.step_id``. ``args`` is unused.
        """
        strip = torch.cat([s.clamp(0, 1) for s in data[:10]], dim=-1)
        self.writer.add_image("samples", strip, self.step_id)


if __name__ == "__main__":
    # Script entry: train an energy model on MNIST loaded from "examples/"
    # (no download), using a Langevin integrator for sampling.
    mnist = MNIST("examples/", download=False, transform=ToTensor())
    data = EnergyDataset(mnist)

    energy = Energy()
    integrator = Langevin(rate=30, steps=50, max_norm=None)

    training_options = dict(
        network_name="conditional-mnist-ebm",
        device="cpu",
        integrator=integrator,
        batch_size=64,
        max_epochs=1000,
        verbose=True,
    )
    training = MNISTEnergyTraining(energy, data, **training_options)

    training.train()
Beispiel #7
0
        energy,
        data,
        optimizer=torch.optim.Adam,
        optimizer_kwargs=dict(lr=2e-4),
        level_distribution=NormalNoise(lambda t: 1e-3 + t * (5.0 - 1e-3)),
        noise_distribution=LangevinNoise(
            energy, lambda t: 1e-2 * (1e-3 + t * (5.0 - 1e-3))),
        ema_weight=0.9999,
        path=opt.path,
        device=opt.device,
        batch_size=opt.batch_size,
        max_epochs=opt.max_epochs,
        report_interval=opt.report_interval,
        checkpoint_interval=opt.checkpoint_interval)
    training.get_step("tdre_step").extend_update(gradient_action=grad_action)

    # add generating images every few steps:
    integrator = Langevin(rate=-0.1,
                          noise=0.01,
                          steps=5,
                          max_norm=None,
                          clamp=None)
    training.add(generate_step=partial(generate_step,
                                       energy=training.energy_target,
                                       integrator=integrator,
                                       ctx=training),
                 every=opt.report_interval)

    training.load()
    training.train()
Beispiel #8
0
def generate_step(energy, base, integrator: Langevin = None, ctx=None):
    """Sample a batch from ``base``, integrate it, and log it as images.

    Args:
      energy: energy model passed to ``integrator.integrate``.
      base: object whose ``sample(n)`` returns a batch with a ``.device``
        (assumed batch-first; TODO confirm).
      integrator: object whose ``integrate`` method is called as
        ``integrate(energy, sample, levels, None)``; required.
      ctx: training context providing ``batch_size`` and ``log``; required.

    Returns:
      The integrated samples, clamped to [0, 1] (the same tensor that is
      logged).

    Raises:
      ValueError: if ``integrator`` or ``ctx`` is None (both default to None
        but are required).
    """
    if integrator is None:
        raise ValueError("generate_step requires an integrator")
    if ctx is None:
        raise ValueError("generate_step requires a ctx")
    sample = base.sample(ctx.batch_size)
    # Size the per-sample levels from the sample itself so they always match
    # its batch dimension; all levels are zero here.
    levels = torch.zeros(sample.size(0), device=sample.device)
    result = integrator.integrate(energy, sample, levels, None)
    result = result.clamp(0, 1)
    ctx.log(samples=LogImage(result))
    return result
        data = batch.final_state
        samples = [torch.clamp(sample, 0, 1) for sample in data[0:10]]
        samples = torch.cat(samples, dim=-1)
        self.writer.add_image("samples", samples, self.step_id)


if __name__ == "__main__":
    import torch.multiprocessing as mp
    mp.set_start_method("spawn")

    mnist = CIFAR10("examples/", download=True, transform=ToTensor())
    data = EnergyDataset(mnist)

    score = Convolutional(depth=4)
    energy = CIFAR10Energy(SharedModule(score, dynamic=True), keep_rate=0.95)
    integrator = Langevin(rate=50, steps=20, max_norm=None, clamp=(0, 1))

    training = CIFAR10EnergyTraining(
        score,
        energy,
        data,
        network_name="off-energy/cifar10-off-energy-2",
        device="cuda:0",
        integrator=integrator,
        off_energy_weight=5,
        batch_size=64,
        off_energy_decay=1,
        decay=1.0,
        n_workers=8,
        double=True,
        buffer_size=10_000,
Beispiel #10
0
    def each_generate(self, data):
        """Log up to the first ten generated samples as one image strip.

        Samples are clamped to [-1, 1] and mapped to [0, 1] via
        ``(x + 1) / 2``. They are then joined along their last dimension and
        written to ``self.writer`` under the tag ``"samples"`` at step
        ``self.step_id``.
        """
        head = data[:10]
        rescaled = [(torch.clamp(x, -1, 1) + 1) / 2 for x in head]
        strip = torch.cat(rescaled, dim=-1)
        self.writer.add_image("samples", strip, self.step_id)


if __name__ == "__main__":
    mnist = CIFAR10("examples/", download=False, transform=ToTensor())
    data = EnergyDataset(mnist)

    energy = ConvEnergy(depth=20)
    integrator = Langevin(rate=1,
                          noise=0.01,
                          steps=20,
                          clamp=None,
                          max_norm=None)

    training = CIFAR10EnergyTraining(
        energy,
        data,
        network_name="classifier-mnist-ebm/cifar-plain",
        device="cuda:0",
        integrator=integrator,
        decay=0.0,
        batch_size=16,
        buffer_size=10000,
        optimizer_kwargs={
            "lr": 1e-4
        },
Beispiel #11
0
        predictions = self.predict(nodes, edges)

        return predictions


if __name__ == "__main__":
    data = EBMNet(sys.argv[1], num_neighbours=15, N=128)
    net = SDP(
        MaterializedEnergy(pair_depth=4,
                           size=128,
                           value_size=16,
                           kernel_size=1,
                           drop=0.0))
    integrator = Langevin(rate=500.0,
                          noise=1.0,
                          steps=50,
                          max_norm=None,
                          clamp=None)
    training = EBMTraining(net,
                           data,
                           batch_size=32,
                           decay=1.0,
                           max_epochs=5000,
                           integrator=integrator,
                           buffer_probability=0.95,
                           buffer_size=10000,
                           optimizer_kwargs={
                               "lr": 1e-3,
                               "betas": (0.0, 0.99)
                           },
                           device="cuda:0",