# Shared imports for the snippets below; the hypergan module paths are
# assumed from the project's source layout.
import torch
from torch.autograd import grad as torch_grad
from hypergan.gan_component import GANComponent
from hypergan.distributions.base_distribution import BaseDistribution


def before_step(self, step, feed_dict, depth=0):
        # Pretrain only once, on the very first training step.
        if self.gan.steps != 1:
            return

        # Instantiate the configured encoder, then build one optimizer over
        # the encoder and generator parameters jointly.
        defn = self.config.encoder.copy()
        klass = GANComponent.lookup_function(None, defn['class'])
        encoder = klass(self.gan, defn).cuda()
        defn = self.config.optimizer.copy()
        klass = GANComponent.lookup_function(None, defn['class'])
        del defn["class"]
        self.optimizer = klass(
            list(encoder.parameters()) + list(self.gan.generator.parameters()),
            **defn)

        # Jointly train the encoder and generator as an autoencoder on real
        # batches, minimizing mean squared reconstruction error.
        for i in range(self.config.steps or 1000):
            self.optimizer.zero_grad()
            inp = self.gan.inputs.next()
            fake = self.gan.generator(encoder(inp))

            loss = ((inp - fake) ** 2).mean()
            loss.backward()
            self.optimizer.step()
            if self.config.verbose:
                print("[autoencode]", i, "loss", loss.item())
            elif self.config.info and i % 100 == 0:
                print("[autoencode]", i, "loss", loss.item())
def before_step(self, step, feed_dict, depth=0):
        # Build a fresh optimizer over the generator parameters from config.
        defn = self.config.optimizer.copy()
        klass = GANComponent.lookup_function(None, defn['class'])
        del defn["class"]
        self.optimizer = klass(self.gan.generator.parameters(), **defn)

        # Push generated samples toward the real distribution as scored by
        # the discriminator, via the configured adversarial-norm loss.
        for i in range(self.config.steps or 1):
            self.optimizer.zero_grad()
            fake = self.gan.discriminator(self.gan.generator(self.gan.latent.instance)).mean()
            real = self.gan.discriminator(self.gan.inputs.sample).mean()
            loss = self.gan.loss.forward_adversarial_norm(real, fake)
            if loss == 0.0:
                if self.config.verbose:
                    print("[match support] No loss")
                break
            # Differentiable gradients of the loss w.r.t. generator parameters.
            move = torch_grad(outputs=loss, inputs=self.gan.g_parameters(), retain_graph=True, create_graph=True)

            if self.config.regularize:
                # Penalize the L1 norm of the gradients and re-differentiate.
                move = torch_grad(outputs=(loss + sum(m.abs().sum() for m in move)), inputs=self.gan.g_parameters(), retain_graph=True, create_graph=True)
            # Copy the computed gradients into the parameters' .grad slots.
            for p, g in zip(self.gan.g_parameters(), move):
                if p.grad is not None:
                    p.grad.copy_(g)
            self.optimizer.step()
            if self.config.verbose:
                print("[match support]", i, "loss", loss.item())
            if self.config.loss_threshold and loss < self.config.loss_threshold:
                if self.config.info:
                    print("[match support] loss_threshold steps", i, "loss", loss.item())
                break
        # Report if the step budget was exhausted without hitting a threshold.
        if self.config.info and self.config.steps and i == self.config.steps - 1:
            print("[match support] steps exhausted", i, "loss", loss.item())
# Example 3
    def create_optimizer(self, name="optimizer"):
        # Resolve the optimizer definition by name, falling back to the
        # default "optimizer" section of the config.
        defn = getattr(self.config, name) or self.config.optimizer
        defn = defn.copy()
        klass = GANComponent.lookup_function(None, defn['class'])
        del defn["class"]

        optimizer = self.trainable_gan.create_optimizer(klass, defn)
        return optimizer
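
# Sketch of the "resolve class, pop 'class', pass the rest as kwargs" pattern
# that create_optimizer relies on. The dict below is an illustrative
# assumption; in hypergan the string form of 'class' is resolved through
# GANComponent.lookup_function.
def _optimizer_from_defn_demo():
    defn = {"class": torch.optim.Adam, "lr": 1e-4, "betas": (0.5, 0.999)}
    klass = defn.pop("class")
    return klass(torch.nn.Linear(2, 2).parameters(), **defn)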
    def __init__(self, gan, config):
        BaseDistribution.__init__(self, gan, config)
        klass = GANComponent.lookup_function(None, self.config['source'])
        self.source = klass(gan, config)
        self.current_channels = config["z"]
        self.current_width = 1
        self.current_height = 1
        self.current_input_size = config["z"]
        self.z = self.source.z
    def next(self):
        sample = self.source.next()
        if self.z_var is None:
            # Lazily create the latent variable that is optimized in place.
            self.z_var = sample.cuda().requires_grad_(True)
        # Rebuild the optimizer on every call so its state resets per sample.
        defn = self.config.optimizer.copy()
        klass = GANComponent.lookup_function(None, defn['class'])
        del defn["class"]
        self.optimizer = klass([self.z_var], **defn)
        z = self.z_var
        z.data.copy_(sample)

        z.grad = torch.zeros_like(z)
        for i in range(self.config.steps or 1):
            self.optimizer.zero_grad()
            # Clamp z with hardtanh before generating, then score real and
            # fake batches with the discriminator.
            fake = self.gan.discriminator(self.gan.generator(
                self.hardtanh(z))).mean()
            real = self.gan.discriminator(self.gan.inputs.sample).mean()
            loss = self.gan.loss.forward_adversarial_norm(real, fake)
            if loss == 0.0:
                if self.config.verbose:
                    print("[optimize distribution] No loss")
                break
            # Differentiable gradient of the loss with respect to z.
            z_move = torch_grad(outputs=loss,
                                inputs=z,
                                retain_graph=True,
                                create_graph=True)
            z_change = z_move[0].abs().mean()
            if self.config.z_change_threshold or self.config.verbose:
                if i == 0:
                    # Baseline movement used to normalize later steps.
                    first_z_change = z_change
            z.grad.copy_(z_move[0])
            self.optimizer.step()
            if self.config.verbose:
                print("[optimize distribution]", i, "loss", loss.item(),
                      "mean movement", (z - sample).abs().mean().item(),
                      (z_change / first_z_change).item())
            if self.config.z_change_threshold and z_change / first_z_change < self.config.z_change_threshold:
                if self.config.info:
                    print("[optimize distribution] z_change_threshold steps",
                          i, "loss", loss.item(), "mean movement",
                          (z - sample).abs().mean().item(),
                          (z_change / first_z_change).item())
                break
            if self.config.loss_threshold and loss < self.config.loss_threshold:
                if self.config.info:
                    print("[optimize distribution] loss_threshold steps", i,
                          "loss", loss.item(), "mean movement",
                          (z - sample).abs().mean().item(),
                          (z_change / first_z_change).item())
                break

        if self.config.info and self.config.steps and i == self.config.steps - 1:
            print("[optimize distribution] steps exhausted", i, "loss",
                  loss.item(), "mean movement",
                  (z - sample).abs().mean().item())
        self.instance = z
        return z
    def __init__(self, gan, config):
        BaseDistribution.__init__(self, gan, config)
        klass = GANComponent.lookup_function(None,
                                             self.config['source']['class'])
        self.source = klass(gan, config['source'])
        self.current_channels = config['source']["z"]
        self.current_width = 1
        self.current_height = 1
        self.current_input_size = config['source']["z"]
        self.z = self.source.z
        self.hardtanh = torch.nn.Hardtanh()
        self.relu = torch.nn.ReLU()
        self.z_var = None
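
# Minimal standalone sketch of the latent-optimization idea in next(): treat z
# as a leaf tensor with requires_grad=True and step an optimizer on z while
# the model stays fixed. The model and loss here are toy assumptions.
def _latent_optimization_demo(steps=10):
    model = torch.nn.Linear(16, 1)
    z = torch.randn(1, 16, requires_grad=True)
    opt = torch.optim.Adam([z], lr=0.05)
    for _ in range(steps):
        opt.zero_grad()
        loss = model(torch.tanh(z)).mean()
        loss.backward()
        opt.step()
    return z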
# Example 7
def create_input(input_config):
    # Resolve the input class by name and instantiate it from its config.
    klass = GANComponent.lookup_function(None, input_config['class'])
    return klass(input_config)
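
# Hedged usage sketch for create_input; the class path and config keys below
# are illustrative assumptions, not a verified hypergan input configuration:
#
#   input_config = {"class": "class:hypergan.inputs.image_loader.ImageLoader",
#                   "directories": ["/data/images"],
#                   "batch_size": 8}
#   inputs = create_input(input_config)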
# Example 8
    def create_input(self, blank=False, rank=None):
        klass = GANComponent.lookup_function(None, self.input_config['class'])
        self.input_config["blank"] = blank
        self.input_config["rank"] = rank
        return klass(self.input_config)
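
# Usage sketch (assumptions: "rank" is a distributed-training rank and "blank"
# requests an input component that skips loading real data):
#
#   inputs = self.create_input(blank=True, rank=0)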