Example #1
# Generator, Critic and __weights_init_normal are assumed to be defined in the surrounding module.
def get_models(latent_dim, model_dim, device, output_dim, channels, init=True):
    """Build a generator/critic pair on `device`, optionally applying normal weight initialisation."""
    generator = Generator(latent_dim, model_dim, channels).to(device)
    critic = Critic(model_dim, output_dim, channels).to(device)
    if init:
        generator.apply(__weights_init_normal)
        critic.apply(__weights_init_normal)
    return generator, critic
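
A hedged usage sketch for get_models: the Generator and Critic classes are not shown here, so the sizes and forward signatures below are assumptions made only for illustration.

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Illustrative sizes; the real values depend on the Generator/Critic definitions.
generator, critic = get_models(latent_dim=100, model_dim=64, device=device,
                               output_dim=1, channels=3)

# Assumption: the generator maps (batch, latent_dim) noise to images and the
# critic maps images to one unbounded score per sample.
z = torch.randn(16, 100, device=device)
fake_images = generator(z)
scores = critic(fake_images)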
Example #2
import pickle

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.autograd import grad as torch_grad

# Generator, Critic and generate_latent are assumed to be defined elsewhere in the project.


class WGAN(object):
    def __init__(self, image_size, input_channels, hidden_channels, output_channels, latent_dimension, lr, device, clamp=0.01, gp_weight=10):
        self.image_size = image_size
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.output_channels = output_channels
        self.latent_dimension = latent_dimension
        self.device = device
        self.clamp = clamp
        self.gp_weight = gp_weight

        self.critic = Critic(image_size, hidden_channels,
                             input_channels).to(device)
        self.generator = Generator(
            image_size, latent_dimension, hidden_channels, output_channels).to(device)

        self.critic.apply(self.weights_init)
        self.generator.apply(self.weights_init)

        # The original WGAN pairs RMSprop with weight clipping; since this class
        # uses a gradient penalty, Adam with betas=(0, 0.9) is used instead.
        self.optimizer_critic = torch.optim.Adam(
            self.critic.parameters(), lr, betas=(0, 0.9))
        self.optimizer_gen = torch.optim.Adam(
            self.generator.parameters(), lr, betas=(0, 0.9))

        self.critic_losses = []
        self.gen_losses = []

        self.losses = []

    def critique(self, x):
        return self.critic(x)

    def generate(self, z):
        return self.generator(z)

    def load_model(self, load_name):
        self.critic.load_state_dict(torch.load(
            load_name + '_critic', map_location='cpu'))
        self.generator.load_state_dict(torch.load(
            load_name + '_generator', map_location='cpu'))

    def _gradient_penalty(self, real_data, generated_data):
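        # WGAN-GP penalty term: gp_weight * E[(||grad_xhat D(xhat)||_2 - 1)^2],
        # where xhat is drawn uniformly along straight lines between real and
        # generated samples.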
        batch_size = real_data.size()[0]

        # Calculate interpolation
        alpha = torch.rand(batch_size, 1, 1, 1, device=self.device)
        alpha = alpha.expand_as(real_data)

        interpolated = alpha * real_data.detach() + \
            (1 - alpha) * generated_data.detach()
        interpolated.requires_grad_(True)

        # Calculate probability of interpolated examples
        prob_interpolated = self.critique(interpolated).to(self.device)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                               grad_outputs=torch.ones(prob_interpolated.size()).to(self.device), create_graph=True, retain_graph=True)[0]

        # Gradients have shape (batch_size, num_channels, img_width, img_height),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1) ** 2).mean()

    def weights_init(self, m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

    def normalise_pixels(self, images):
        normalised_images = (images - 0.5)/0.5
        return normalised_images

    def denormalise_pixels(self, images):
        denormalised_images = (images * 0.5) + 0.5
        #denormalised_images = np.clip(denormalised_images, 0, 1)
        return denormalised_images

    def train(self, data_loader, num_epochs, n_iter=5, log_iter=1, test_iter=1, test_latents=None, save_name=None, load_name=None):
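        # Training alternates updates: every batch takes a critic step on
        # -(E[D(real)] - E[D(fake)]) + gradient_penalty, and every critic_iters
        # batches the generator takes one step on -E[D(G(z))].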
        if load_name is not None:
            self.load_model(load_name)

        d_err = 0
        g_err = 0
        for epoch in range(num_epochs):
            if epoch % test_iter == 0 and test_latents is not None:
                images = self.generate(test_latents)
                self.display_images(images)

            # Number of critic updates per generator update; a warm-up with
            # extra critic iterations (e.g. 100) could be scheduled for early
            # epochs here.
            critic_iters = n_iter

            j = 1
            for i, batch in enumerate(data_loader):
                real_images, _ = batch
                real_images = self.normalise_pixels(real_images)

                real_images = real_images.to(self.device)
                d_real = self.critique(real_images)

                self.optimizer_critic.zero_grad()
                # Generate a batch of latents
                latent_zs = generate_latent(
                    self.latent_dimension, len(real_images)).to(self.device)

                # Transform latents into images using the generator
                fake_images = self.generate(latent_zs).to(self.device)

                # Classify fakes with discriminator
                d_fake = self.critique(fake_images.detach())

                gradient_penalty = self._gradient_penalty(
                    real_images, fake_images)

                d_err = -(torch.mean(d_real) - torch.mean(d_fake)) + \
                    gradient_penalty
                d_err.backward()

                # Gradient descent step for discriminator
                self.optimizer_critic.step()

                # Weight clipping from the original WGAN formulation; unnecessary
                # when the gradient penalty is used, so it stays disabled.
                # for p in self.critic.parameters():
                #     p.data.clamp_(-self.clamp, self.clamp)

                j += 1

                if j % critic_iters == 0:
                    ### Generator ###
                    self.optimizer_gen.zero_grad()

                    # Generate a batch of latents
                    latent_zs = generate_latent(
                        self.latent_dimension, len(real_images)).to(self.device)
                    # Transform latents into images using the generator
                    fake_images = self.generate(latent_zs).to(self.device)
                    d_fake = self.critique(fake_images)
                    g_err = -torch.mean(d_fake)
                    g_err.backward()

                    self.optimizer_gen.step()

                    j = 1

            if epoch % log_iter == 0:
                # Store plain floats so the autograd graphs are not retained.
                self.critic_losses.append(float(d_err))
                self.gen_losses.append(float(g_err))
                print('Epoch %d: Wasserstein Loss: %.4f\t Generator Loss: %.4f' % (
                    epoch, -float(d_err), float(g_err)))

                if save_name is not None:
                    torch.save(self.critic.state_dict(), save_name + '_critic')
                    torch.save(self.generator.state_dict(),
                               save_name + '_generator')

                    with open(save_name + '_critic_losses.pkl', 'wb') as f:
                        pickle.dump(self.critic_losses, f)

                    with open(save_name + '_gen_losses.pkl', 'wb') as f:
                        pickle.dump(self.gen_losses, f)

    def display_images(self, images):
        fig, ax = plt.subplots(1, 5)

        # Show the first five images in a row with axes hidden.
        for i in range(5):
            image = self.denormalise_pixels(
                images[i]).cpu().detach().permute(1, 2, 0)

            ax[i].imshow(image)

            ax[i].spines["top"].set_visible(False)
            ax[i].spines["right"].set_visible(False)
            ax[i].spines["bottom"].set_visible(False)
            ax[i].spines["left"].set_visible(False)
            ax[i].set_xticks([])
            ax[i].set_yticks([])

        fig.set_figheight(20)
        fig.set_figwidth(20)

        plt.show()
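
A minimal usage sketch for the WGAN class above, assuming CIFAR-10-sized 32x32 RGB images, a Generator/Critic pair built for those shapes, and the generate_latent helper referenced inside train; the dataset, hyper-parameters and file name are illustrative only.

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Illustrative data pipeline; the image size and channel counts must match
# whatever the actual Generator and Critic definitions expect.
dataset = datasets.CIFAR10('data', train=True, download=True,
                           transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

gan = WGAN(image_size=32, input_channels=3, hidden_channels=64,
           output_channels=3, latent_dimension=100, lr=1e-4, device=device)

# Five fixed latents so display_images shows the same samples every test_iter epochs.
test_latents = generate_latent(100, 5).to(device)

gan.train(loader, num_epochs=50, n_iter=5, test_latents=test_latents,
          save_name='wgan_cifar')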