Example #1
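    # NOTE: this snippet shows only the constructor; its enclosing class (the
    # unconditional counterpart of ConditionalProGAN in Example #2) is not shown.
    # `th` refers to torch (import torch as th), and Generator / Discriminator are
    # the ProGAN networks defined elsewhere in the source project.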
    def __init__(self, depth=7, latent_size=512, learning_rate=0.001, beta_1=0,
                 beta_2=0.99, eps=1e-8, drift=0.001, n_critic=1, use_eql=True,
                 loss="wgan-gp", use_ema=True, ema_decay=0.999,
                 device=th.device("cuda")):
        """
        constructor for the class
        :param depth: depth of the GAN (will be used for each generator and discriminator)
        :param latent_size: latent size of the manifold used by the GAN
        :param learning_rate: learning rate for Adam
        :param beta_1: beta_1 for Adam
        :param beta_2: beta_2 for Adam
        :param eps: epsilon for Adam
        :param n_critic: number of times to update discriminator
                         (Used only if loss is wgan or wgan-gp)
        :param drift: drift penalty added to the discriminator loss
                      (Used only if loss is wgan or wgan-gp)
        :param use_eql: whether to use equalized learning rate
        :param loss: the loss function to be used
                     Can either be a string =>
                          ["wgan-gp", "wgan", "lsgan", "lsgan-with-sigmoid"]
                     Or an instance of GANLoss
        :param use_ema: boolean for whether to use exponential moving averages
        :param ema_decay: value of mu for ema
        :param device: device to run the GAN on (GPU / CPU)
        """

        from torch.optim import Adam

        # Create the Generator and the Discriminator
        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)
        self.dis = Discriminator(depth, latent_size, use_eql=use_eql).to(device)

        # state of the object
        self.latent_size = latent_size
        self.depth = depth
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        self.n_critic = n_critic
        self.use_eql = use_eql
        self.device = device
        self.drift = drift

        # define the optimizers for the discriminator and generator
        self.gen_optim = Adam(self.gen.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        self.dis_optim = Adam(self.dis.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        # define the loss function used for training the GAN
        self.loss = self.__setup_loss(loss)

        # setup the ema for the generator
        if self.use_ema:
            from networks.CustomLayers import EMA
            self.ema = EMA(self.ema_decay)
            self.__register_generator_to_ema()
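
The constructor above hands the generator parameters to an EMA helper imported from networks.CustomLayers (Example #2 below does the same); that helper is not included in these snippets. Below is a minimal sketch of such a helper, written purely to match how it is called here (register(name, tensor) during construction and ema(name, tensor) when the averages are applied); the real implementation in the source project may differ.

class EMA:
    """Minimal sketch of an exponential-moving-average store (assumed behaviour)."""

    def __init__(self, mu):
        self.mu = mu        # decay factor, the `ema_decay` value (e.g. 0.999)
        self.shadow = {}    # parameter name -> running-average tensor

    def register(self, name, val):
        # keep an initial copy of the parameter as the starting average
        self.shadow[name] = val.clone()

    def __call__(self, name, x):
        # shadow <- mu * shadow + (1 - mu) * x, returned so the caller can
        # overwrite the live parameter with its running average
        new_average = self.mu * self.shadow[name] + (1.0 - self.mu) * x
        self.shadow[name] = new_average.clone()
        return new_average

With this assumed behaviour, __apply_ema_on_generator (shown in Example #2) replaces every trainable generator parameter with its running average after each generator step.
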
Example #2
import numpy as np
import torch as th

# NOTE: Generator, ConditionalDiscriminator and the networks.* modules used below
# are assumed to be importable from elsewhere in the source project.


class ConditionalProGAN:
    """ Wrapper around the Generator and the Discriminator """

    def __init__(self, embedding_size, depth=7, latent_size=512, compressed_latent_size=128,
                 learning_rate=0.001, beta_1=0, beta_2=0.99,
                 eps=1e-8, drift=0.001, n_critic=1, use_eql=True,
                 loss="wgan-gp", use_ema=True, ema_decay=0.999,
                 device=th.device("cuda")):
        """
        constructor for the class
        :param embedding_size: size of the encoded text embeddings
        :param depth: depth of the GAN (will be used for each generator and discriminator)
        :param latent_size: latent size of the manifold used by the GAN
        :param compressed_latent_size: size of the compressed latent vectors
        :param learning_rate: learning rate for Adam
        :param beta_1: beta_1 for Adam
        :param beta_2: beta_2 for Adam
        :param eps: epsilon for Adam
        :param n_critic: number of times to update discriminator
                         (Used only if loss is wgan or wgan-gp)
        :param drift: drift penalty added to the discriminator loss
                      (Used only if loss is wgan or wgan-gp)
        :param use_eql: whether to use equalized learning rate
        :param loss: the loss function to be used
                     Can either be a string =>
                          ["wgan-gp", "wgan"]
                     Or an instance of ConditionalGANLoss
        :param use_ema: boolean for whether to use exponential moving averages
        :param ema_decay: value of mu for ema
        :param device: device to run the GAN on (GPU / CPU)
        """

        from torch.optim import Adam

        # Create the Generator and the Discriminator
        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)
        self.dis = ConditionalDiscriminator(depth, embedding_size, compressed_latent_size,
                                            use_eql=use_eql).to(device)

        # state of the object
        self.latent_size = latent_size
        self.compressed_latent_size = compressed_latent_size
        self.depth = depth
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        self.n_critic = n_critic
        self.use_eql = use_eql
        self.device = device
        self.drift = drift

        # define the optimizers for the discriminator and generator
        self.gen_optim = Adam(self.gen.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        self.dis_optim = Adam(self.dis.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        # define the loss function used for training the GAN
        self.loss = self.__setup_loss(loss)

        # setup the ema for the generator
        if self.use_ema:
            from networks.CustomLayers import EMA
            self.ema = EMA(self.ema_decay)
            self.__register_generator_to_ema()

    def __register_generator_to_ema(self):
        for name, param in self.gen.named_parameters():
            if param.requires_grad:
                self.ema.register(name, param.data)

    def __apply_ema_on_generator(self):
        for name, param in self.gen.named_parameters():
            if param.requires_grad:
                param.data = self.ema(name, param.data)

    def __setup_loss(self, loss):
        import networks.Losses as losses

        if isinstance(loss, str):
            loss = loss.lower()  # lowercase the string
            if loss == "wgan":
                loss = losses.CondWGAN_GP(self.device, self.dis, self.drift, use_gp=False)
                # note: if you use plain wgan, you will have to apply weight clipping
                # to the discriminator in order to prevent exploding gradients
                # (illustrated in the commented lines below)
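                # For reference (not part of this class), clipping is typically
                # applied to every critic parameter after each discriminator update:
                #     for p in self.dis.parameters():
                #         p.data.clamp_(-0.01, 0.01)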

            elif loss == "wgan-gp":
                loss = losses.CondWGAN_GP(self.device, self.dis, self.drift, use_gp=True)

            else:
                raise ValueError("Unknown loss function requested")

        elif not isinstance(loss, losses.ConditionalGANLoss):
            raise ValueError("loss is neither an instance of GANLoss nor a string")

        return loss

    def optimize_discriminator(self, noise, real_batch, latent_vector, depth, alpha,
                               use_matching_aware=True):
        """
        performs one step of weight update on the discriminator using the given batch of data
        :param noise: input noise for sample generation
        :param real_batch: batch of real samples
        :param latent_vector: conditional latent vector (e.g. the encoded text embedding)
        :param depth: current depth of optimization
        :param alpha: current alpha for fade-in
        :param use_matching_aware: whether to use matching aware discrimination
        :return: current loss (Wasserstein loss)
        """
        from torch.nn import AvgPool2d
        from torch.nn.functional import interpolate  # upsample() is deprecated

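        # Real samples must match what the generator produces during fade-in: the
        # batch is average-pooled down to the current depth's resolution and
        # alpha-blended with a 2x-upsampled copy of the previous (half) resolution.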
        # downsample the real_batch for the given depth
        down_sample_factor = int(np.power(2, self.depth - depth - 1))
        prior_downsample_factor = max(int(np.power(2, self.depth - depth)), 0)

        ds_real_samples = AvgPool2d(down_sample_factor)(real_batch)

        if depth > 0:
            prior_ds_real_samples = interpolate(AvgPool2d(prior_downsample_factor)(real_batch),
                                                scale_factor=2)
        else:
            prior_ds_real_samples = ds_real_samples

        # real samples are a combination of ds_real_samples and prior_ds_real_samples
        real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)

        loss_val = 0
        for _ in range(self.n_critic):
            # generate a batch of samples
            fake_samples = self.gen(noise, depth, alpha).detach()

            loss = self.loss.dis_loss(real_samples, fake_samples,
                                      latent_vector, depth, alpha)

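            # Matching-aware discrimination: pair the real images with randomly
            # permuted (hence mismatched) conditioning vectors and penalize the
            # discriminator for scoring such mismatched pairs highly, so it learns
            # to reject image/condition pairs that do not match.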
            if use_matching_aware:
                # calculate the matching aware distribution loss
                mis_match_text = latent_vector[np.random.permutation(latent_vector.shape[0]), :]
                m_a_d = self.dis(real_samples, mis_match_text, depth, alpha)
                loss = loss + th.mean(m_a_d)

            # optimize discriminator
            self.dis_optim.zero_grad()
            loss.backward()
            self.dis_optim.step()

            loss_val += loss.item()

        return loss_val / self.n_critic

    def optimize_generator(self, noise, latent_vector, depth, alpha):
        """
        performs one step of weight update on the generator for the given batch
        :param noise: input random noise required for generating samples
        :param latent_vector: conditional latent vector (e.g. the encoded text embedding)
        :param depth: depth of the network at which optimization is done
        :param alpha: value of alpha for fade-in effect
        :return: current loss (Wasserstein estimate)
        """

        # generate fake samples:
        fake_samples = self.gen(noise, depth, alpha)

        # TODO: Change this implementation to make it compatible with relativistic GANs
        loss = self.loss.gen_loss(None, fake_samples, latent_vector, depth, alpha)

        # optimize the generator
        self.gen_optim.zero_grad()
        loss.backward(retain_graph=True)
        self.gen_optim.step()

        # if use_ema is true, apply ema to the generator parameters
        if self.use_ema:
            self.__apply_ema_on_generator()

        # return the loss value
        return loss.item()
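
For context, here is a rough, hypothetical usage sketch of one training step with ConditionalProGAN. The embedding size, batch shapes, image resolution, and the way the input noise is built are assumptions for illustration only; in particular, the surrounding pipeline may construct `noise` from a compressed copy of the embedding plus random noise, which is outside the scope of the snippet above.

import torch as th

device = th.device("cuda" if th.cuda.is_available() else "cpu")
c_pro_gan = ConditionalProGAN(embedding_size=256, depth=7, latent_size=512,
                              device=device)

batch_size = 16
current_depth = 3                               # resolution stage being trained
alpha = 0.5                                     # fade-in weight for the newest block
resolution = 4 * 2 ** (c_pro_gan.depth - 1)     # assumed full resolution of the data

real_batch = th.randn(batch_size, 3, resolution, resolution, device=device)
embeddings = th.randn(batch_size, 256, device=device)   # conditional latent vectors
noise = th.randn(batch_size, c_pro_gan.latent_size, device=device)

dis_loss = c_pro_gan.optimize_discriminator(noise, real_batch, embeddings,
                                             current_depth, alpha)
gen_loss = c_pro_gan.optimize_generator(noise, embeddings, current_depth, alpha)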