Example #1
 def integrate(self, score, data, *args):
     done = False
     count = 0
     # the first call (self.step == 0) runs 10x more steps to burn in:
     step_count = self.steps if self.step > 0 else 10 * self.steps
     while not done:
         make_differentiable(data)
         make_differentiable(args)
         energy = score(data + self.noise * torch.randn_like(data), *args)
         if isinstance(energy, (list, tuple)):
             energy, *_ = energy
         gradient = ag.grad(energy, data, torch.ones_like(energy))[0]
         if self.max_norm:
             gradient = clip_grad_by_norm(gradient, self.max_norm)
         data = data - self.rate * gradient
         if self.clamp is not None:
             data = data.clamp(*self.clamp)
         data = data.detach()
         done = count >= step_count
         if self.target is not None:
             # when a target energy is set, also require the mean energy
             # to have dropped below it before stopping:
             done = done and bool((energy.mean(dim=0) <= self.target).all())
         count += 1
         # periodically re-randomize the sample to escape bad modes:
         if (count + 1) % 500 == 0:
             data.random_()
     self.step += 1
     return data
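All of these snippets rely on two helpers the listing never defines, make_differentiable and clip_grad_by_norm. A minimal sketch of plausible implementations, assuming make_differentiable toggles requires_grad on tensors (including tensors nested in lists and tuples) and clip_grad_by_norm rescales each per-sample gradient to a maximum L2 norm:

import torch

def make_differentiable(data, toggle=True):
    # Toggle requires_grad on a tensor, or on every tensor nested inside
    # a list/tuple (assumed behaviour, for illustration only).
    if torch.is_tensor(data) and data.is_floating_point():
        data.requires_grad_(toggle)
    elif isinstance(data, (list, tuple)):
        for item in data:
            make_differentiable(item, toggle=toggle)

def clip_grad_by_norm(gradient, max_norm):
    # Rescale each sample's gradient so its L2 norm is at most max_norm.
    norm = gradient.reshape(gradient.size(0), -1).norm(dim=1)
    scale = (max_norm / (norm + 1e-12)).clamp(max=1.0)
    return gradient * scale.reshape(-1, *([1] * (gradient.dim() - 1)))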
Example #2
    def step(self, score, data, *args):
        make_differentiable(data)
        make_differentiable(args)
        # data = ...
        if isinstance(data, (list, tuple)):
            data = [
                item + noise * torch.randn_like(item)
                for noise, item in zip(self.noise, data)
            ]
        else:
            data = data + self.noise * torch.randn_like(data)
        energy = score(data, *args)
        if isinstance(energy, (list, tuple)):
            energy, *_ = energy

        gradient = ag.grad(energy, data, torch.ones_like(energy))
        if isinstance(data, (list, tuple)):
            data = list(data)
            for idx, (rate, clamp, gradval) in enumerate(
                    zip(self.rate, self.clamp, gradient)):
                data[idx] = data[idx] - rate * gradval
                if clamp is not None:
                    data[idx] = data[idx].clamp(*clamp)
        else:
            gradient = gradient[0]
            if self.max_norm:
                gradient = clip_grad_by_norm(gradient, self.max_norm)
            data = data - self.rate * gradient
            if self.clamp is not None:
                data = data.clamp(*self.clamp)
        return data
Example #3
    def integrate(self, score, data, *args):
        data = data.clone()
        current_energy, *_ = score(data, *args)
        for idx in range(self.steps):
            make_differentiable(data)
            make_differentiable(args)

            energy = score(data, *args)
            if isinstance(energy, (list, tuple)):
                energy, *_ = energy

            gradient = ag.grad(energy, data.tensor, torch.ones_like(energy))[0]
            if self.max_norm:
                gradient = clip_grad_by_norm(gradient, self.max_norm)

            # attempt at gradient based local update of discrete variables:
            grad_prob = (-500 * gradient).softmax(dim=1)
            new_prob = self.noise + self.rate * grad_prob + (
                1 - self.noise - self.rate) * data.tensor
            new_val = hard_one_hot(new_prob.log())
            data.tensor = new_val

            data = data.detach()

        return data
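Examples #3 and #27 update discrete one-hot data through hard_one_hot, which the listing does not define. A plausible sketch, assuming it draws a hard one-hot sample from per-position log-probabilities via the Gumbel-max trick:

import torch

def hard_one_hot(log_probs, dim=1):
    # Gumbel-max sampling: perturb the log-probabilities with Gumbel noise
    # and one-hot encode the argmax along `dim` (assumed behaviour).
    gumbel = -torch.log(-torch.log(torch.rand_like(log_probs) + 1e-20) + 1e-20)
    index = (log_probs + gumbel).argmax(dim=dim, keepdim=True)
    return torch.zeros_like(log_probs).scatter_(dim, index, 1.0)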
Example #4
    def forward(self, inputs, noise, *args):
        with torch.enable_grad():
            make_differentiable(inputs)

            cond = torch.zeros(inputs.size(0),
                               10,
                               dtype=inputs.dtype,
                               device=inputs.device)
            offset = (torch.log(noise) / torch.log(torch.tensor(0.60))).long()
            cond[torch.arange(inputs.size(0)), offset.view(-1)] = 1
            out = self.preprocess(inputs)
            count = 0
            for bn, proj, block in zip(self.bn, self.project, self.blocks):
                out = func.elu(bn(proj(out) + block(out), cond))
                count += 1
                if count % 5 == 0:
                    out = func.avg_pool2d(out, 2)
            out = self.postprocess(out)
            out = func.adaptive_avg_pool2d(out, 1).view(-1, 128)
            logits = self.predict(out)
            energy = -logits.logsumexp(dim=1)
            score = -torch.autograd.grad(energy,
                                         inputs,
                                         torch.ones_like(energy),
                                         create_graph=True,
                                         retain_graph=True)[0]
            return score, logits
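Examples #4 and #20 recover a score network from a classifier via the logsumexp trick of joint energy-based models: the logits define an energy E(x) = -logsumexp_y f_y(x), and differentiating that energy with respect to the inputs (with create_graph=True so the result stays trainable) yields a score, letting the same weights serve both classification and sampling.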
Example #5
    def mutate(self, score, data, *args):
        result = data.clone()

        make_differentiable(result)
        make_differentiable(args)
        energy = score(result, *args)
        gradient = ag.grad(energy, result,
                           torch.ones(*energy.shape, device=result.device))[0]

        # position choice
        position_gradient = -gradient.sum(dim=1)
        position_distribution = torch.distributions.Categorical(
            logits=position_gradient)
        position_proposal = position_distribution.sample()

        # change choice
        change_gradient = -gradient[torch.arange(0, gradient.size(0)), :,
                                    position_proposal]
        change_distribution = torch.distributions.Categorical(
            logits=change_gradient)
        change_proposal = change_distribution.sample()

        # mutate:
        result[torch.arange(0, result.size(0)), :, position_proposal] = 0
        result[torch.arange(0, result.size(0)), change_proposal,
               position_proposal] = 1

        return result.detach()
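Example #5 is a gradient-informed Gibbs mutation for one-hot sequences: the negated energy gradient first parameterizes a Categorical distribution over which position to change, then one over the replacement character at the chosen position, in the spirit of gradient-based discrete samplers such as Gibbs-with-Gradients.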
Example #6
    def run_energy(self, data):
        data, labels = data

        make_differentiable(data)
        input_data, *data_args = self.data_key(data)
        real_logits = self.score(input_data, *data_args)

        # sample after first pass over real data, to catch
        # possible batch-norm shenanigans without blowing up.
        fake = self.sample()

        if self.step_id % self.report_interval == 0:
            detached, *args = self.data_key(to_device(fake, "cpu"))
            self.each_generate(detached.detach(), *args)

        make_differentiable(fake)
        input_fake, *fake_args = self.data_key(fake)
        fake_logits = self.score(input_fake, *fake_args)

        real_result = self.logit_energy(real_logits, *data_args)
        fake_result = self.logit_energy(fake_logits, *fake_args)

        # set integrator target, if appropriate.
        if self.integrator.target is None:
            self.integrator.target = real_result.detach().mean(dim=0)
        self.integrator.target = (0.6 * self.integrator.target +
                                  0.4 * real_result.detach().mean(dim=0))
        return real_result, fake_result, real_logits, labels
Example #7
def run_sliced_score(score, data, args, noise_level=0.0):
    if noise_level:
        data = gaussian_noise(data, noise_level)
    make_differentiable(data)
    score_val = score(data, noise_level, *args)
    loss = sliced_score(score_val, data)
    return loss, namespace(data=data, score=score_val, noise_level=noise_level)
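The sliced_score loss used here is not part of the listing. A sketch of a sliced score matching objective in the style of Song et al. (2019), with the variance-reduced 1/2 ||s||^2 norm term, assuming score was computed from data with the graph retained:

import torch

def sliced_score(score, data):
    # Random projection vectors for the Hutchinson-style Jacobian estimate.
    vectors = torch.randn_like(score)
    # v^T (ds/dx) v via a single vector-Jacobian product.
    grad_v = (vectors * score).reshape(score.size(0), -1).sum()
    jacobian = torch.autograd.grad(grad_v, data, create_graph=True)[0]
    jacobian_term = (vectors * jacobian).reshape(score.size(0), -1).sum(dim=-1)
    norm_term = (score ** 2).reshape(score.size(0), -1).sum(dim=-1) / 2
    return (jacobian_term + norm_term).mean()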
Example #8
 def run_discriminator(self, data):
     with torch.no_grad():
         _, fake_batches, _ = self.run_generator(data)
     make_differentiable(fake_batches)
     make_differentiable(data)
     fake_result = self.discriminator(fake_batches[0])
     real_result = self.discriminator(data)
     return fake_batches[0], data, fake_result, real_result
Example #9
 def run_energy(self, data):
     data, *args = self.data_key(data)
     noisy, sigma = self.noise(data)
     make_differentiable(noisy)
     result = self.score(noisy, sigma, *args)
     return (
         result,  #.view(result.size(0), -1),
         data,  #.view(result.size(0), -1),
         noisy,  #.view(result.size(0), -1),
         sigma,  #.view(result.size(0), -1)
     )
Example #10
 def run_discriminator(self, data):
     with torch.no_grad():
         fake = self.run_generator(data)
     make_differentiable(fake)
     make_differentiable(data)
     _, fake_batch, _, _ = fake
     inputs, available, requested = data
     fake_result = self._run_discriminator_aux(inputs, fake_batch,
                                               available, requested)
     real_result = self._run_discriminator_aux(inputs, inputs, available,
                                               requested)
     return fake, data, fake_result, real_result
Example #11
    def sample(self):
        buffer_iter = iter(self.buffer_loader(self.buffer))
        data, *args = to_device(self.data_key(next(buffer_iter)), self.device)
        data = self.integrator.integrate(self.score, data, *args).detach()
        detached = data.detach().cpu()
        update = (to_device((detached[idx], *[arg[idx]
                                              for arg in args]), "cpu")
                  for idx in range(data.size(0)))
        make_differentiable(update, toggle=False)
        self.buffer.update(update)

        return to_device((detached, *args), self.device)
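The sample methods in examples #11 through #13 assume a persistent sample buffer with update and sample operations, as used in persistent contrastive divergence. The real implementation is not shown; a minimal sketch:

import random

class SampleBuffer:
    # Bounded FIFO store of detached (data, *args) tuples.
    def __init__(self, capacity=10_000):
        self.capacity = capacity
        self.items = []

    def update(self, samples):
        self.items.extend(samples)
        self.items = self.items[-self.capacity:]

    def sample(self, count):
        return random.sample(self.items, count)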
Example #12
    def sample(self):
        buffer_iter = iter(self.buffer_loader(self.buffer))
        data, *args = to_device(self.data_key(next(buffer_iter)), self.device)
        self.score.eval()
        data = detach(self.integrator.integrate(self.score, data, *args))
        self.score.train()
        detached = to_device(data, "cpu")
        make_differentiable(args, toggle=False)
        update = self.decompose_batch(detached, *args)
        make_differentiable(update, toggle=False)
        self.buffer.update(update)

        return to_device((detached, *args), self.device)
Example #13
    def sample(self):
        buffer_iter = iter(self.buffer_loader(self.buffer))
        data, *args = to_device(self.data_key(next(buffer_iter)), self.device)
        self.sampler.eval()
        with torch.no_grad():
            data = self.integrate(data, *args)
        self.sampler.train()
        detached = to_device(data.detach(), "cpu")
        update = self.decompose_batch(detached, *args)
        make_differentiable(update, toggle=False)
        self.buffer.update(update)

        return to_device((detached, *args), self.device)
Example #14
 def integrate(self, score, data, *args):
     for idx in range(self.steps):
         make_differentiable(data)
         make_differentiable(args)
         energy, *_ = score(data + self.noise * torch.randn_like(data),
                            *args)
         gradient = ag.grad(energy, data,
                            torch.ones(*energy.shape,
                                       device=data.device))[0]
         if self.max_norm:
             gradient = clip_grad_by_norm(gradient, self.max_norm)
         data = self.update(data - self.rate * gradient)
     return data
Example #15
 def run_regularizer(self, data):
     make_differentiable(data)
     codes = self.encode(data)
     discriminator_result = self.discriminator(data, *codes)
     gradient = torch.autograd.grad(
         discriminator_result,
         self.mixing_key(data),
         grad_outputs=torch.ones_like(discriminator_result),
         create_graph=True,
         retain_graph=True)[0]
     gradient = gradient.view(gradient.size(0), -1)
     gradient = (gradient**2).sum(dim=1)
     return gradient
Example #16
    def run_energy(self, real, fake, args):
        make_differentiable(real)
        real_result = self.score(real, *args)
        grad = ag.grad(real_result,
                       real,
                       grad_outputs=torch.ones_like(real_result),
                       create_graph=True,
                       retain_graph=True)[0]
        grad_norm = (grad.view(grad.size(0), -1)**2).sum(dim=1)

        fake_result = self.score(fake, *args)

        return real_result, fake_result, grad_norm
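grad_norm here is a squared gradient penalty on real samples, in the style of R1 regularization. One hypothetical way to fold the three outputs into a training loss (gamma is an assumed hyperparameter, not part of the listing):

real_result, fake_result, grad_norm = self.run_energy(real, fake, args)
loss = real_result.mean() - fake_result.mean() + gamma * grad_norm.mean()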
Example #17
def energy_score_aux(energy, grad_args, *args, **kwargs):
    grad_vars = _select(args, kwargs, grad_args)
    make_differentiable(grad_vars)
    E = energy(*args, **kwargs)
    score = grad(
        E,
        grad_vars,
        # E = -log p + C -> score = -grad E
        grad_outputs=-torch.ones_like(E),
        create_graph=True)
    if len(score) == 1:
        score = score[0]
    return score
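A hypothetical call, assuming _select picks out the variables named by grad_args from the positional and keyword arguments (energy_net and the index 0 are illustrative only):

score = energy_score_aux(energy_net, (0,), inputs, sigma)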
Example #18
 def integrate(self, score, data, *args):
     for idx in range(self.steps):
         make_differentiable(data)
         make_differentiable(args)
         energy = score(data + self.noise * torch.randn_like(data), *args)
         if isinstance(energy, (list, tuple)):
             energy, *_ = energy
         gradient = ag.grad(energy, data, torch.ones_like(energy))[0]
         if self.max_norm:
             gradient = clip_grad_by_norm(gradient, self.max_norm)
         data = data - self.rate * gradient
         if self.clamp is not None:
             data = data.clamp(*self.clamp)
     return data
Example #19
    def run_energy(self, data):
        fake = self.sample()

        if self.step_id % self.report_interval == 0:
            detached, *args = self.data_key(to_device(fake, "cpu"))
            self.each_generate(detached.detach(), *args)

        make_differentiable(fake)
        make_differentiable(data)
        input_data, *data_args = self.data_key(data)
        input_fake, *fake_args = self.data_key(fake)
        real_result = self.score(input_data, *data_args)
        fake_result = self.score(input_fake, *fake_args)
        return real_result, fake_result
Example #20
 def forward(self, inputs, sigma, *args):
     with torch.enable_grad():
         make_differentiable(inputs)
         processed = self.process(inputs.view(inputs.size(0), -1))
         condition = self.condition(sigma.view(sigma.size(0), -1))
         scale = self.scale(condition)
         bias = self.bias(condition)
         logits = self.predict(processed * scale + bias)
         energy = -logits.logsumexp(dim=1)
         score = torch.autograd.grad(energy,
                                     inputs,
                                     grad_outputs=torch.ones_like(energy),
                                     retain_graph=True,
                                     create_graph=True)[0]
     return score, logits
Example #21
 def integrate(self, score, data, *args):
     for idx in range(self.steps):
         make_differentiable(data)
         make_differentiable(args)
         energy = score(data, *args)
         if isinstance(energy, (list, tuple)):
             energy, *_ = energy
         gradient = self.gradient_factor * ag.grad(
             energy, data, torch.ones_like(energy))[0]
         noise = self.noise * torch.randn_like(
             data) if self.take_noise else 0.0
         data = data - self.noise**2 / 2 * gradient + noise
         if self.clamp is not None:
             data = data.clamp(*self.clamp)
     return data
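Example #21 is a step of unadjusted Langevin dynamics: with step size derived from self.noise, each iteration applies data <- data - (noise^2 / 2) * grad E(data) + noise * z with z ~ N(0, I), whose stationary distribution approximates p(x) proportional to exp(-E(x)) when gradient_factor is 1 and take_noise is enabled.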
Example #22
    def run_energy(self, data):
        make_differentiable(data)
        input_data, *data_args = self.data_key(data)
        real_result = self.score(input_data, *data_args)

        # sample after first pass over real data, to catch
        # possible batch-norm shenanigans without blowing up.
        fake = self.sample()

        if self.step_id % self.report_interval == 0:
            detached, *args = self.data_key(to_device(fake, "cpu"))
            self.each_generate(detached.detach(), *args)

        make_differentiable(fake)
        input_fake, *fake_args = self.data_key(fake)
        #self.score.eval()
        fake_result = self.score(input_fake, *fake_args)
        #self.score.train()
        return real_result, fake_result
Example #23
  def contrastive_step(self, data):
    """Performs a single step of contrastive training.

    Args:
      data: data points used for training.
    """
    if self.step_id % self.report_interval == 0:
      self.visualize(data)

    self.optimizer.zero_grad()
    data = to_device(data, self.device)
    make_differentiable(data)
    args = self.run_networks(data)
    loss_val = self.loss(*args)

    self.log_statistics(loss_val, name="total loss")

    loss_val.backward()
    self.optimizer.step()
Example #24
    def run_energy(self, data):
        fake = self.sample()
        if self.oos_penalty:
            oos = self.out_of_sample()

        if self.step_id % self.report_interval == 0:
            detached, *args = self.data_key(fake)
            self.each_generate(detached, *args)

        make_differentiable(fake)
        make_differentiable(data)
        if self.oos_penalty:
            make_differentiable(oos)
        input_data, reference_data, *data_args = self.data_key(data)
        input_fake, reference_fake, *fake_args = self.data_key(fake)
        if self.oos_penalty:
            input_oos, reference_oos, *oos_args = self.data_key(oos)
        real_result, real_parameters = self.score(input_data, reference_data,
                                                  *data_args)
        fake_result, fake_parameters = self.score(input_fake, reference_fake,
                                                  *fake_args)
        oos_result = None
        oos_parameters = None
        if self.oos_penalty:
            oos_result, oos_parameters = self.score(input_oos, reference_oos,
                                                    *oos_args)
        return (real_result, fake_result, oos_result, real_parameters,
                fake_parameters, oos_parameters)
Example #25
    def run_energy(self, data):
        make_differentiable(data)
        input_data, *data_args = self.data_key(data)
        real_result = self.score(input_data, *data_args)

        # sample after first pass over real data, to catch
        # possible batch-norm shenanigans without blowing up.
        fake = self.sample()

        if self.step_id % self.report_interval == 0:
            detached, *args = self.data_key(to_device(fake, "cpu"))
            self.each_generate(detach(detached), *args)

        make_differentiable(fake)
        input_fake, *fake_args = self.data_key(fake)

        self.score.eval()
        fake_result = None
        fake_update_result = None
        if self.sampler_likelihood:
            # Sampler log likelihood:
            fake_result = self.score(input_fake, *fake_args)
            fake_update = self.integrator.step(self.score, input_fake,
                                               *fake_args)
            self.update_target()
            fake_update_result = self.target_score(fake_update, *fake_args)

        comparison = None
        if self.maximum_entropy:
            # Sampler entropy (assumes sampler_likelihood is also enabled,
            # since fake_update is reused from above):
            compare_update = fake_update
            compare_target = to_device(
                self.data_key(self.buffer.sample(self.batch_size))[0],
                compare_update.device)
            if hasattr(self.score, "embed"):
                compare_update = self.target_score.embedding(compare_update)
                compare_target = self.target_score.embedding(compare_target)
            comparison = self.compare(compare_update, compare_target.detach())
        self.score.train()
        return real_result, fake_result, fake_update_result, comparison
Example #26
    def energy_loss(self, score, data, noisy, sigma):
        vectors = self.noise_vectors(score)
        make_differentiable(vectors)

        grad_v = (score * vectors).view(score.size(0), -1).sum()
        jacobian = torch.autograd.grad(grad_v,
                                       noisy.tensor,
                                       retain_graph=True,
                                       create_graph=True)[0]
        # jacograd = torch.autograd.grad(jacobian.mean(), noisy.tensor, retain_graph=True)
        # print(jacobian)
        # print(jacograd)

        norm = (score**2).view(score.size(0), -1).sum(dim=-1) / 2
        jacobian = (vectors * jacobian).view(score.size(0), -1).sum(dim=-1)

        result = (norm + jacobian) * sigma.view(score.size(0), -1)**2

        result = result.mean()

        self.current_losses["ebm"] = float(result)

        return result
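energy_loss above implements denoising sliced score matching, with probe vectors drawn by noise_vectors, which the listing does not define. A plausible sketch using Rademacher probes (an assumption; Gaussian probes would also work):

import torch

def noise_vectors(self, score):
    # Rademacher probe vectors (entries +1 or -1) matching the score's shape.
    return torch.randint_like(score, 0, 2) * 2 - 1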
Example #27
    def integrate(self, score, data, *args):
        data = data.clone()
        result = data.clone()
        current_energy = score(data, *args)
        for idx in range(self.steps):
            make_differentiable(data)
            make_differentiable(args)

            energy, deltas = score(data, *args, return_deltas=True)

            # attempt at gradient based local update of discrete variables:
            grad_prob = torch.zeros_like(deltas)
            grad_prob[torch.arange(deltas.size(0)), deltas.argmax(dim=1)] = 1
            if self.scale is not None:
                grad_prob = (self.scale * deltas).softmax(dim=1)
            new_prob = self.noise + self.rate * grad_prob + (
                1 - self.noise - self.rate) * data.tensor
            new_val = hard_one_hot(new_prob.log())
            data.tensor = new_val

            data = data.detach()

        return data
Example #28
    def run_discriminator(self, data):
        with torch.no_grad():
            _, (fake_fw, fake_rv), _ = self.run_generator(data)
        make_differentiable(fake_fw)
        make_differentiable(fake_rv)
        make_differentiable(data)
        real_result_fw = self.fw_discriminator(data[1])
        fake_result_fw = self.fw_discriminator(fake_fw)
        real_result_rv = self.rv_discriminator(data[0])
        fake_result_rv = self.rv_discriminator(fake_rv)

        real_result = (real_result_fw, real_result_rv)
        fake_result = (fake_result_fw, fake_result_rv)
        fake_batch = fake_fw, fake_rv
        real_batch = (data[1], data[0])

        return fake_batch, real_batch, fake_result, real_result
Example #29
 def run_energy(self, data):
     data, *args = self.data_key(data)
     make_differentiable(data)
     critic = self.critic(data, *args)
     score = self.score(data, *args)
     return data, score, critic