    def __init__(self, depth=7, latent_size=512, learning_rate=0.001, beta_1=0,
                 beta_2=0.99, eps=1e-8, drift=0.001, n_critic=1, use_eql=True,
                 loss="wgan-gp", use_ema=True, ema_decay=0.999,
                 device=th.device("cuda")):
        """
        constructor for the class
        :param depth: depth of the GAN (will be used for each generator and discriminator)
        :param latent_size: latent size of the manifold used by the GAN
        :param learning_rate: learning rate for Adam
        :param beta_1: beta_1 for Adam
        :param beta_2: beta_2 for Adam
        :param eps: epsilon for Adam
        :param n_critic: number of times to update the discriminator per generator update
                         (used only if loss is wgan or wgan-gp)
        :param drift: drift penalty for the discriminator
                      (used only if loss is wgan or wgan-gp)
        :param use_eql: whether to use equalized learning rate
        :param loss: the loss function to be used
                     Can either be a string => ["wgan-gp", "wgan", "lsgan", "lsgan-with-sigmoid"]
                     Or an instance of GANLoss
        :param use_ema: boolean for whether to use exponential moving averages
        :param ema_decay: value of mu for ema
        :param device: device to run the GAN on (GPU / CPU)
        """
        from torch.optim import Adam

        # Create the Generator and the Discriminator
        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)
        self.dis = Discriminator(depth, latent_size, use_eql=use_eql).to(device)

        # state of the object
        self.latent_size = latent_size
        self.depth = depth
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        self.n_critic = n_critic
        self.use_eql = use_eql
        self.device = device
        self.drift = drift

        # define the optimizers for the discriminator and generator
        self.gen_optim = Adam(self.gen.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)
        self.dis_optim = Adam(self.dis.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        # define the loss function used for training the GAN
        self.loss = self.__setup_loss(loss)

        # setup the ema for the generator
        if self.use_ema:
            from networks.CustomLayers import EMA
            self.ema = EMA(self.ema_decay)
            self.__register_generator_to_ema()
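

# --- Illustrative usage sketch (an assumption, not part of the original file) ---
# Minimal example of how the unconditional wrapper whose constructor is defined
# above can be used to sample images. It assumes the enclosing class is named
# `ProGAN`, that `Generator`, `Discriminator` and `th` are imported at module
# level, and the usual progressive-growing convention of 4x4 images at depth 0.
# Depth / alpha values below are arbitrary example choices.
def _demo_progan_sampling(num_samples=4, sample_depth=2, fade_in_alpha=1.0):
    """ sketch: build a ProGAN with default hyperparameters and sample fake images """
    pro_gan = ProGAN(depth=4, latent_size=512, device=th.device("cpu"))

    # latent vectors for the generator are standard normal noise
    noise = th.randn(num_samples, pro_gan.latent_size).to(pro_gan.device)

    # same call signature as used in optimize_generator below:
    # gen(noise, depth, alpha); images come out at 4 * 2 ** sample_depth per side
    with th.no_grad():
        return pro_gan.gen(noise, sample_depth, fade_in_alpha)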

class ConditionalProGAN:
    """ Wrapper around the Generator and the Discriminator """

    def __init__(self, embedding_size, depth=7, latent_size=512,
                 compressed_latent_size=128, learning_rate=0.001, beta_1=0,
                 beta_2=0.99, eps=1e-8, drift=0.001, n_critic=1, use_eql=True,
                 loss="wgan-gp", use_ema=True, ema_decay=0.999,
                 device=th.device("cuda")):
        """
        constructor for the class
        :param embedding_size: size of the encoded text embeddings
        :param depth: depth of the GAN (will be used for each generator and discriminator)
        :param latent_size: latent size of the manifold used by the GAN
        :param compressed_latent_size: size of the compressed latent vectors
        :param learning_rate: learning rate for Adam
        :param beta_1: beta_1 for Adam
        :param beta_2: beta_2 for Adam
        :param eps: epsilon for Adam
        :param n_critic: number of times to update the discriminator per generator update
                         (used only if loss is wgan or wgan-gp)
        :param drift: drift penalty for the discriminator
                      (used only if loss is wgan or wgan-gp)
        :param use_eql: whether to use equalized learning rate
        :param loss: the loss function to be used
                     Can either be a string => ["wgan-gp", "wgan"]
                     Or an instance of ConditionalGANLoss
        :param use_ema: boolean for whether to use exponential moving averages
        :param ema_decay: value of mu for ema
        :param device: device to run the GAN on (GPU / CPU)
        """
        from torch.optim import Adam

        # Create the Generator and the Discriminator
        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)
        self.dis = ConditionalDiscriminator(depth, embedding_size,
                                            compressed_latent_size,
                                            use_eql=use_eql).to(device)

        # state of the object
        self.latent_size = latent_size
        self.compressed_latent_size = compressed_latent_size
        self.depth = depth
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        self.n_critic = n_critic
        self.use_eql = use_eql
        self.device = device
        self.drift = drift

        # define the optimizers for the discriminator and generator
        self.gen_optim = Adam(self.gen.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)
        self.dis_optim = Adam(self.dis.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        # define the loss function used for training the GAN
        self.loss = self.__setup_loss(loss)

        # setup the ema for the generator
        if self.use_ema:
            from networks.CustomLayers import EMA
            self.ema = EMA(self.ema_decay)
            self.__register_generator_to_ema()

    def __register_generator_to_ema(self):
        # register every trainable generator parameter with the EMA tracker
        for name, param in self.gen.named_parameters():
            if param.requires_grad:
                self.ema.register(name, param.data)

    def __apply_ema_on_generator(self):
        # replace each trainable generator parameter with its moving average
        for name, param in self.gen.named_parameters():
            if param.requires_grad:
                param.data = self.ema(name, param.data)

    def __setup_loss(self, loss):
        import networks.Losses as losses

        if isinstance(loss, str):
            loss = loss.lower()  # lowercase the string

            if loss == "wgan":
                loss = losses.CondWGAN_GP(self.device, self.dis, self.drift, use_gp=False)
                # note that if you use plain wgan, you will have to use weight clipping
                # in order to prevent gradient explosion

            elif loss == "wgan-gp":
                loss = losses.CondWGAN_GP(self.device, self.dis, self.drift, use_gp=True)

            else:
                raise ValueError("Unknown loss function requested")

        elif not isinstance(loss, losses.ConditionalGANLoss):
            raise ValueError("loss is neither an instance of ConditionalGANLoss nor a string")

        return loss

    def optimize_discriminator(self, noise, real_batch, latent_vector,
                               depth, alpha, use_matching_aware=True):
        """
        performs one step of weight update on discriminator using the batch of data
        :param noise: input noise for sample generation
        :param real_batch: real samples batch
        :param latent_vector: conditional latent vector (encoded text embedding)
        :param depth: current depth of optimization
        :param alpha: current alpha for fade-in
        :param use_matching_aware: whether to use matching aware discrimination
        :return: current loss (Wasserstein loss)
        """
        from torch.nn import AvgPool2d
        from torch.nn.functional import interpolate  # replaces the deprecated F.upsample

        # downsample the real_batch to the resolution of the given depth
        down_sample_factor = int(np.power(2, self.depth - depth - 1))
        prior_downsample_factor = max(int(np.power(2, self.depth - depth)), 0)

        ds_real_samples = AvgPool2d(down_sample_factor)(real_batch)

        if depth > 0:
            prior_ds_real_samples = interpolate(
                AvgPool2d(prior_downsample_factor)(real_batch), scale_factor=2)
        else:
            prior_ds_real_samples = ds_real_samples

        # real samples are a fade-in mix of ds_real_samples and prior_ds_real_samples
        real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)

        loss_val = 0
        for _ in range(self.n_critic):
            # generate a batch of samples
            fake_samples = self.gen(noise, depth, alpha).detach()

            loss = self.loss.dis_loss(real_samples, fake_samples,
                                      latent_vector, depth, alpha)

            if use_matching_aware:
                # calculate the matching aware distribution loss by pairing
                # real images with randomly permuted (mismatched) embeddings
                mis_match_text = latent_vector[np.random.permutation(latent_vector.shape[0]), :]
                m_a_d = self.dis(real_samples, mis_match_text, depth, alpha)
                loss = loss + th.mean(m_a_d)

            # optimize discriminator
            self.dis_optim.zero_grad()
            loss.backward()
            self.dis_optim.step()

            loss_val += loss.item()

        return loss_val / self.n_critic

    def optimize_generator(self, noise, latent_vector, depth, alpha):
        """
        performs one step of weight update on generator for the given batch of data
        :param noise: input random noise required for generating samples
        :param latent_vector: conditional latent vector (encoded text embedding)
        :param depth: depth of the network at which optimization is done
        :param alpha: value of alpha for fade-in effect
        :return: current loss (Wasserstein estimate)
        """

        # generate fake samples:
        fake_samples = self.gen(noise, depth, alpha)

        # TODO: Change this implementation to make it compatible with relativisticGAN
        loss = self.loss.gen_loss(None, fake_samples, latent_vector, depth, alpha)

        # optimize the generator
        self.gen_optim.zero_grad()
        loss.backward(retain_graph=True)
        self.gen_optim.step()

        # if use_ema is true, apply ema to the generator parameters
        if self.use_ema:
            self.__apply_ema_on_generator()

        # return the loss value
        return loss.item()
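

# --- Illustrative training-step sketch (an assumption, not part of the original file) ---
# Shows how optimize_discriminator / optimize_generator are typically driven from a
# training loop. The batch below is random placeholder data: a real pipeline would
# supply images from a dataloader and `embeddings` from a text encoder, and would
# usually build the generator input by combining a compressed embedding with random
# noise rather than using pure noise as done here. Sizes assume 4x4 images at depth 0.
if __name__ == "__main__":
    c_pro_gan = ConditionalProGAN(embedding_size=256, depth=4,
                                  latent_size=512, device=th.device("cpu"))

    batch_size, current_depth, alpha = 4, 2, 0.5
    full_resolution = 4 * (2 ** (c_pro_gan.depth - 1))  # resolution of the real images

    # placeholder batch: random "images" and random "text" embeddings
    images = th.randn(batch_size, 3, full_resolution, full_resolution).to(c_pro_gan.device)
    embeddings = th.randn(batch_size, 256).to(c_pro_gan.device)

    # pure random noise as generator input (see the note above)
    gan_input = th.randn(batch_size, c_pro_gan.latent_size).to(c_pro_gan.device)

    # one discriminator step followed by one generator step
    dis_loss = c_pro_gan.optimize_discriminator(gan_input, images, embeddings,
                                                current_depth, alpha)
    gen_loss = c_pro_gan.optimize_generator(gan_input, embeddings, current_depth, alpha)

    print("dis_loss: %.4f   gen_loss: %.4f" % (dis_loss, gen_loss))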