def before_step(self, step, feed_dict, depth=0):
    """Pretrain the generator as the decoder of an autoencoder (runs once).

    Only does work when ``self.gan.steps == 1``. An encoder built from
    ``self.config.encoder`` is trained jointly with ``self.gan.generator`` to
    minimise the mean squared error between real inputs and their
    reconstructions ``generator(encoder(input))``.

    Args:
        step: Unused here.
        feed_dict: Unused here.
        depth: Unused here.

    Config keys read from ``self.config``:
        encoder: component definition; its ``class`` entry names the encoder
            class and the whole dict (including ``class``) is passed to it.
        optimizer: component definition; ``class`` names the optimizer and the
            remaining entries are passed as keyword arguments.
        steps: number of pretraining iterations (1000 when falsy).
        verbose: print the loss on every iteration.
        info: print the loss every 100 iterations.

    Side effects:
        Sets ``self.optimizer`` and updates generator parameters in place.
        The encoder is local and discarded afterwards.
    """
    if self.gan.steps != 1:
        return

    encoder_defn = self.config.encoder.copy()
    encoder_class = GANComponent.lookup_function(None, encoder_defn['class'])
    encode = encoder_class(self.gan, encoder_defn).cuda()

    optimizer_defn = self.config.optimizer.copy()
    optimizer_class = GANComponent.lookup_function(None, optimizer_defn['class'])
    del optimizer_defn["class"]
    self.optimizer = optimizer_class(
        list(encode.parameters()) + list(self.gan.generator.parameters()),
        **optimizer_defn)

    # Enable gradients before the first forward pass so parameters that
    # arrive frozen still receive gradients on iteration 0 (setting the flag
    # after backward() has no effect on the gradients already computed).
    for p in list(self.gan.g_parameters()) + list(encode.parameters()):
        p.requires_grad = True

    for i in range(self.config.steps or 1000):
        self.optimizer.zero_grad()
        inp = self.gan.inputs.next()
        fake = self.gan.generator(encode(inp))
        loss = ((inp - fake) ** 2).mean()
        loss.backward()
        self.optimizer.step()
        # One print per iteration at most, even when verbose and info are both set.
        if self.config.verbose or (self.config.info and i % 100 == 0):
            print("[autoencode]", i, "loss", loss.item())
def before_step(self, step, feed_dict, depth=0):
    """Take adversarial optimizer steps on the generator before a training step.

    Builds a fresh optimizer over ``self.gan.generator.parameters()`` on every
    call, so optimizer state does not persist between calls. Then, for up to
    ``self.config.steps`` iterations (1 when falsy), minimises
    ``self.gan.loss.forward_adversarial_norm(real, fake)``. Here ``real`` and
    ``fake`` are the mean discriminator outputs on ``self.gan.inputs.sample``
    and on ``generator(self.gan.latent.instance)``.

    Args:
        step: Unused here.
        feed_dict: Unused here.
        depth: Unused here.

    Config keys read from ``self.config``:
        optimizer: component definition; ``class`` names the optimizer and the
            remaining entries are passed as keyword arguments.
        steps: maximum number of iterations.
        regularize: also penalise the L1 magnitude of the gradient itself.
        loss_threshold: stop early once the loss falls below this value.
        verbose / info: logging switches.

    NOTE(review): gradients are computed for ``self.gan.g_parameters()`` while
    the optimizer steps ``self.gan.generator.parameters()``; this assumes the
    two cover the same tensors — confirm.
    """
    optimizer_defn = self.config.optimizer.copy()
    optimizer_class = GANComponent.lookup_function(None, optimizer_defn['class'])
    del optimizer_defn["class"]
    self.optimizer = optimizer_class(self.gan.generator.parameters(), **optimizer_defn)

    # Resolve the default once so the final-step check below cannot hit None - 1.
    steps = self.config.steps or 1
    g_params = list(self.gan.g_parameters())

    for i in range(steps):
        self.optimizer.zero_grad()
        fake = self.gan.discriminator(self.gan.generator(self.gan.latent.instance)).mean()
        real = self.gan.discriminator(self.gan.inputs.sample).mean()
        loss = self.gan.loss.forward_adversarial_norm(real, fake)
        if loss == 0.0:
            if self.config.verbose:
                print("[match support] No loss")
            break

        # create_graph=True keeps the gradient differentiable, which the
        # regularized second differentiation below relies on.
        move = torch_grad(outputs=loss, inputs=g_params, retain_graph=True, create_graph=True)
        if self.config.regularize:
            penalty = sum(m.abs().sum() for m in move)
            move = torch_grad(outputs=loss + penalty, inputs=g_params, retain_graph=True, create_graph=True)

        # Assign the gradients instead of copying into existing buffers:
        # .grad is None before the first backward and after zero_grad() with
        # set_to_none, and skipping those parameters made step() a no-op.
        for p, g in zip(g_params, move):
            p.grad = g.detach()
        self.optimizer.step()

        if self.config.verbose:
            print("[match support]", i, "loss", loss.item())
        if self.config.loss_threshold and loss < self.config.loss_threshold:
            if self.config.info:
                print("[match support] loss_threshold steps", i, "loss", loss.item())
            break
        if self.config.info and i == steps - 1:
            print("[match support] steps_threshold steps", i, "loss", loss.item())
def create_optimizer(self, name="optimizer"):
    """Build an optimizer from the config section called ``name``.

    When ``self.config`` has no such section, or it is falsy, the default
    ``self.config.optimizer`` section is used instead. The section is copied,
    its ``class`` entry is resolved via ``GANComponent.lookup_function``, and
    the remaining entries are handed to ``self.trainable_gan.create_optimizer``.

    Args:
        name: attribute name of the optimizer section on ``self.config``.

    Returns:
        Whatever ``self.trainable_gan.create_optimizer`` returns.

    Raises:
        KeyError: if the chosen section has no ``class`` entry.
    """
    # Default of None makes a missing section fall back instead of raising
    # AttributeError on config objects without a permissive __getattr__.
    defn = (getattr(self.config, name, None) or self.config.optimizer).copy()
    klass = GANComponent.lookup_function(None, defn.pop("class"))
    return self.trainable_gan.create_optimizer(klass, defn)
def __init__(self, gan, config):
    """Wrap a source distribution whose class name is ``config['source']``.

    The source class is resolved from ``self.config['source']`` and built with
    the same ``gan`` and ``config`` passed here. Its latent size
    ``config["z"]`` is recorded as the channel count and input size, with a
    width and height of 1.
    """
    BaseDistribution.__init__(self, gan, config)
    source_class = GANComponent.lookup_function(None, self.config['source'])
    self.source = source_class(gan, config)

    z_size = config["z"]
    self.current_channels = z_size
    self.current_input_size = z_size
    self.current_width = 1
    self.current_height = 1
    self.z = self.source.z
def next(self):
    """Draw a latent sample and refine it adversarially before returning it.

    A sample from ``self.source.next()`` is copied into a persistent latent
    tensor ``self.z_var``, created on the first call together with
    ``self.optimizer``. The latent is then optimised for up to
    ``self.config.steps`` iterations (1 when falsy) to minimise
    ``self.gan.loss.forward_adversarial_norm(real, fake)``. Here ``fake`` is
    the mean discriminator score of ``generator(hardtanh(z))`` and ``real``
    is the mean score on ``self.gan.inputs.sample``.

    Config keys read from ``self.config``: optimizer, steps, verbose, info,
    z_change_threshold (stop once the gradient magnitude falls below this
    fraction of its first-iteration value), loss_threshold.

    Returns:
        The optimised latent tensor, also stored as ``self.instance``.
        NOTE(review): this is the same tensor object on every call and is
        overwritten in place by the next call; the returned value is not
        passed through hardtanh — confirm both are intended.
    """
    sample = self.source.next()
    if self.z_var is None:
        # Create the latent as a CUDA *leaf* tensor with its own storage.
        # Enabling requires_grad before .cuda() yields a non-leaf tensor for
        # CPU samples (rejected by optimizers) and aliases CUDA samples.
        self.z_var = sample.detach().to("cuda", copy=True).requires_grad_(True)
        defn = self.config.optimizer.copy()
        klass = GANComponent.lookup_function(None, defn['class'])
        del defn["class"]
        self.optimizer = klass([self.z_var], **defn)
    z = self.z_var
    z.data.copy_(sample)
    z.grad = torch.zeros_like(z)

    def _mean_movement():
        # Mean absolute distance z has moved from the drawn sample; .to()
        # keeps this valid when the source produces samples on another device.
        return (z - sample.to(z.device)).abs().mean().item()

    # Resolve the default once so the final-step check cannot hit None - 1.
    steps = self.config.steps or 1
    first_z_change = None
    for i in range(steps):
        self.optimizer.zero_grad()
        fake = self.gan.discriminator(self.gan.generator(self.hardtanh(z))).mean()
        real = self.gan.discriminator(self.gan.inputs.sample).mean()
        loss = self.gan.loss.forward_adversarial_norm(real, fake)
        if loss == 0.0:
            if self.config.verbose:
                print("[optimize distribution] No loss")
            break

        z_move = torch_grad(outputs=loss, inputs=z, retain_graph=True, create_graph=True)
        z_change = z_move[0].abs().mean()
        if i == 0:
            # Always record the reference magnitude: every log line and the
            # relative stopping rule below divide by it.
            first_z_change = z_change
        relative_change = z_change / first_z_change

        # Assign rather than copy_: zero_grad() may have reset .grad to None.
        z.grad = z_move[0].detach()
        self.optimizer.step()

        if self.config.verbose:
            print("[optimize distribution]", i, "loss", loss.item(), "mean movement", _mean_movement(), relative_change.item())
        if self.config.z_change_threshold and relative_change < self.config.z_change_threshold:
            if self.config.info:
                print("[optimize distribution] z_change_threshold steps", i, "loss", loss.item(), "mean movement", _mean_movement(), relative_change.item())
            break
        if self.config.loss_threshold and loss < self.config.loss_threshold:
            if self.config.info:
                print("[optimize distribution] loss_threshold steps", i, "loss", loss.item(), "mean movement", _mean_movement(), relative_change.item())
            break
        if self.config.info and i == steps - 1:
            print("[optimize distribution] steps_threshold steps", i, "loss", loss.item(), "mean movement", _mean_movement())
    self.instance = z
    return z
def __init__(self, gan, config):
    """Wrap a source distribution described by the ``config['source']`` section.

    The source class is resolved from ``self.config['source']['class']`` and
    constructed with ``config['source']``. Its latent size
    ``config['source']["z"]`` is recorded as the channel count and input size,
    with a width and height of 1. Hardtanh and ReLU modules are created, and
    ``z_var`` starts out as None.
    """
    BaseDistribution.__init__(self, gan, config)
    source_class = GANComponent.lookup_function(None, self.config['source']['class'])
    source_config = config['source']
    self.source = source_class(gan, source_config)

    z_size = source_config["z"]
    self.current_channels = z_size
    self.current_input_size = z_size
    self.current_width = 1
    self.current_height = 1
    self.z = self.source.z

    self.hardtanh = torch.nn.Hardtanh()
    self.relu = torch.nn.ReLU()
    self.z_var = None
def create_input(input_config):
    """Instantiate the input component named by ``input_config['class']``.

    The entire ``input_config`` mapping (including ``class``) is passed to the
    resolved class's constructor, and the new instance is returned.
    """
    return GANComponent.lookup_function(None, input_config['class'])(input_config)
def create_input(self, blank=False, rank=None):
    """Instantiate the configured input component.

    Resolves the class named by ``self.input_config['class']``. Then it writes
    ``blank`` and ``rank`` into ``self.input_config`` in place, so the values
    persist on the config after this call. Finally it constructs the class with
    that config.

    Args:
        blank: stored under the ``"blank"`` key before construction.
        rank: stored under the ``"rank"`` key before construction.

    Returns:
        The newly constructed input component.
    """
    config = self.input_config
    input_class = GANComponent.lookup_function(None, config['class'])
    for key, value in (("blank", blank), ("rank", rank)):
        config[key] = value
    return input_class(config)