Exemplo n.º 1
0
    for k in range(splits):
        part_gen = preds_gen[k * (N // splits):(k + 1) * (N // splits), :]
        part_real = preds_real[k * (N // splits):(k + 1) * (N // splits), :]
        py_gen = np.mean(part_gen, axis=0)
        py_real = np.mean(part_real, axis=0)
        KL_gen_real = entropy(py_gen, py_real)
        scores = []
        for i in range(part_gen.shape[0]):
            pyx = part_gen[i, :]
            scores.append(entropy(pyx, py_gen))
        split_scores.append(np.exp(np.mean(scores) - KL_gen_real))

    return np.mean(split_scores), np.std(split_scores)


if __name__ == "__main__":
    # Restore a trained WGAN generator checkpoint on the CPU.
    gan = DCGAN()
    gan.load_model("checkpoints/trained_wgan/wgan-gen.pt", use_cuda=False)

    # Sample 32 * 2 images and wrap them in a TensorDataset whose "labels"
    # are the images themselves, so the scorers can iterate (x, y) pairs.
    samples = gan.generate_img(n=32 * 2)
    gen_imgs = TensorDataset(samples.data, samples.data)

    print("Computing Inception score...")
    print(inception_score(gen_imgs, cuda=False, resize=True, splits=4))

    print("Computing Mode score...")
    # Only the first two items of the real-data iterable are used
    # (presumably two batches of 32 -- TODO confirm load_dataset's 2nd arg).
    celeba_loader = utils.load_dataset("../data/celebA_all", 32)
    real_imgs = itertools.islice(celeba_loader, 2)
    print(mode_score(gen_imgs, real_imgs, cuda=False, resize=True, splits=4))
    # TODO: plot GAN vs WGAN inception scores with utils.plot_error_bars
    # (call was previously left commented out here).
Exemplo n.º 2
0
    if not os.path.isdir(args.dir):
        os.mkdir(args.dir)

    # Create random tensors from seeds
    if args.latent or args.screen:
        s0, s1 = args.latent if args.latent else args.screen
        space = 'latent' if args.latent else 'screen'
        print('Interpolating random seeds {:d} & {:d} in {} space...'.format(s0, s1, space))
        z0 = gan.create_latent_var(1, s0)
        z1 = gan.create_latent_var(1, s1)

        # Interpolate
        if args.latent:
            imgs = latent_lerp(gan, z0, z1, args.nb_frames)
        else:
            x0 = gan.generate_img(z0)
            x1 = gan.generate_img(z1)
            imgs = screen_lerp(x0, x1, args.nb_frames)

        # Save files
        for i, img in enumerate(imgs):
            img = utils.unnormalize(img)
            fname_in = '{}/frame{:d}.png'.format(args.dir, i)
            torchvision.utils.save_image(img, fname_in, padding=0)
            # Generate frames for perfect looping
            if args.video:
                fname_out = "{}/frame{:d}.png".format(args.dir, 2*args.nb_frames - i - 1)
                torchvision.utils.save_image(img, fname_out, padding=0)
        print("Interpolated {} images saved in {}".format(args.nb_frames, args.dir))

        # Make video
Exemplo n.º 3
0
class CelebA(object):
    """Implement DCGAN for CelebA dataset.

    Bundles a ``DCGAN`` model with its generator/discriminator optimizers and
    provides the training loop, statistics pickling and sample montages.
    """

    def __init__(self, train_params, ckpt_params, gan_params):
        """Store the configuration, create output directories and compile.

        Args:
            train_params: dict providing 'root_dir', 'gen_dir', 'batch_size',
                'train_len', 'learning_rate', 'momentum' (passed to Adam as
                ``betas``), 'optim' ('adam' or 'rmsprop') and 'use_cuda'.
            ckpt_params: dict providing 'batch_report_interval', 'ckpt_path'
                and 'save_stats_interval'.
            gan_params: dict providing 'gan_type', 'latent_dim', 'n_critic'.

        Raises:
            NotImplementedError: if 'optim' is not a supported optimizer.
        """
        # Training parameters
        self.root_dir = train_params['root_dir']
        self.gen_dir = train_params['gen_dir']
        self.batch_size = train_params['batch_size']
        self.train_len = train_params['train_len']
        self.learning_rate = train_params['learning_rate']
        self.momentum = train_params['momentum']
        self.optim = train_params['optim']
        self.use_cuda = train_params['use_cuda']

        # Checkpoint parameters (when, where)
        self.batch_report_interval = ckpt_params['batch_report_interval']
        self.ckpt_path = ckpt_params['ckpt_path']
        self.save_stats_interval = ckpt_params['save_stats_interval']

        # Create output directories (including missing parents) if needed
        for path in (self.ckpt_path, self.gen_dir):
            if not os.path.isdir(path):
                os.makedirs(path)

        # GAN parameters
        self.gan_type = gan_params['gan_type']
        self.latent_dim = gan_params['latent_dim']
        self.n_critic = gan_params['n_critic']

        # Number of full batches per epoch (train() drops a trailing partial
        # batch; assumes train_len is the size of the data it iterates).
        self.num_batches = self.train_len // self.batch_size

        # Get ready to ruuummmmmmble
        self.compile()

    def compile(self):
        """Build the GAN, its optimizers and the fixed inference latents.

        Raises:
            NotImplementedError: if ``self.optim`` is not 'adam'/'rmsprop'.
        """
        # Create new GAN
        self.gan = DCGAN(self.gan_type, self.latent_dim, self.batch_size,
                         self.use_cuda)

        # Move to the GPU *before* building the optimizers, as recommended
        # by torch.optim, so they reference the final parameters.
        if torch.cuda.is_available() and self.use_cuda:
            self.gan = self.gan.cuda()

        # Set optimizers for generator and discriminator
        if self.optim == 'adam':
            self.G_optimizer = optim.Adam(self.gan.G.parameters(),
                                          lr=self.learning_rate,
                                          betas=self.momentum)
            self.D_optimizer = optim.Adam(self.gan.D.parameters(),
                                          lr=self.learning_rate,
                                          betas=self.momentum)
        elif self.optim == 'rmsprop':
            self.G_optimizer = optim.RMSprop(self.gan.G.parameters(),
                                             lr=self.learning_rate)
            self.D_optimizer = optim.RMSprop(self.gan.D.parameters(),
                                             lr=self.learning_rate)
        else:
            raise NotImplementedError(
                'Unsupported optimizer: {}'.format(self.optim))

        # Fixed latent variables, reused by eval(while_training=True) so the
        # montages of successive epochs are comparable.
        self.latent_vars = [self.gan.create_latent_var(1) for _ in range(100)]

    def save_stats(self, stats):
        """Pickle ``stats`` to '<ckpt_path>/<gan_type>-stats.pkl'."""
        fname_pkl = '{}/{}-stats.pkl'.format(self.ckpt_path, self.gan_type)
        print('Saving model statistics to: {}'.format(fname_pkl))
        with open(fname_pkl, 'wb') as fp:
            pickle.dump(stats, fp)

    def _should_report(self, batch_idx):
        """Return True if learning stats should be reported after batch_idx.

        Reports every ``batch_report_interval`` batches (skipping batch 0).
        When the interval equals the number of batches per epoch, that rule
        can never fire, so the report happens once, on the last batch.
        """
        if self.batch_report_interval == self.num_batches:
            return batch_idx + 1 == self.num_batches
        return batch_idx > 0 and batch_idx % self.batch_report_interval == 0

    def eval(self, n, epoch=None, while_training=False):
        """Sample ``n`` images from the generator and tile them in a montage.

        Args:
            n: number of images; the montage is an m x m grid with
                m = int(sqrt(n)), so n should be a perfect square.
            epoch: if given, the montage is '<ckpt_path>/epoch<epoch+1>.png',
                otherwise '<ckpt_path>/epoch.png'.
            while_training: reuse ``self.latent_vars[i]`` (only 100 exist,
                so n must be <= 100) instead of fresh random latents.
        """
        # Evaluation mode
        self.gan.G.eval()

        # Montage size (square)
        m = int(np.sqrt(n))

        try:
            for i in range(n):
                # Reuse fixed latent variables to keep random process intact
                if while_training:
                    img = self.gan.generate_img(self.latent_vars[i])
                else:
                    img = self.gan.generate_img()
                img = utils.unnormalize(img.squeeze())
                fname_in = '{}/test{:d}.png'.format(self.ckpt_path, i)
                torchvision.utils.save_image(img, fname_in)

            suffix = str(epoch + 1) if epoch is not None else ''
            fname_out = '{}/epoch{}.png'.format(self.ckpt_path, suffix)
            # Argument list (no shell): paths containing spaces stay intact.
            # The 'test*' pattern is not shell-expanded; it relies on
            # ImageMagick's own filename globbing, as the original call did.
            sp.call(['montage', '{}/test*'.format(self.ckpt_path),
                     '-tile', '{}x{}'.format(m, m),
                     '-geometry', '64x64+1+1', fname_out])
        finally:
            # Always remove the temporary per-image files
            for f in glob.glob('{}/test*'.format(self.ckpt_path)):
                os.remove(f)

    def train(self, nb_epochs, data_loader):
        """Train the GAN for ``nb_epochs`` passes over ``data_loader``.

        Args:
            nb_epochs: number of epochs.
            data_loader: iterable of ``(images, labels)`` batches; labels are
                ignored and a trailing batch smaller than ``batch_size`` is
                discarded.

        Returns:
            ``(G_loss, D_loss)`` from the last generator / discriminator
            update, or ``None`` for a loss whose update never ran.
        """
        # Initialize tracked quantities and prepare everything
        G_all_losses, D_all_losses = [], []
        G_loss, D_loss = None, None
        utils.format_hdr(self.gan, self.root_dir, self.train_len)
        start = datetime.datetime.now()

        g_iter = 0

        # Train
        for epoch in range(nb_epochs):
            print('EPOCH {:d} / {:d}'.format(epoch + 1, nb_epochs))
            G_losses, D_losses = utils.AvgMeter(), utils.AvgMeter()
            start_epoch = datetime.datetime.now()
            avg_time_per_batch = utils.AvgMeter()

            # Mini-batch SGD
            for batch_idx, (x, _) in enumerate(data_loader):

                # Critic update ratio: WGAN uses 20 critic steps early on
                # (first 50 generator iterations) and every 500th one.
                if self.gan_type == 'wgan' and (
                        g_iter < 50 or (g_iter + 1) % 500 == 0):
                    n_critic = 20
                else:
                    n_critic = self.n_critic

                # Training mode
                self.gan.G.train()

                # Discard last examples to simplify code
                if x.size(0) != self.batch_size:
                    break
                batch_start = datetime.datetime.now()

                # Print progress bar
                utils.progress_bar(batch_idx, self.batch_report_interval,
                                   G_losses.avg, D_losses.avg)

                x = Variable(x)
                if torch.cuda.is_available() and self.use_cuda:
                    x = x.cuda()

                # Update discriminator
                D_loss, _ = self.gan.train_D(x, self.D_optimizer,
                                             self.batch_size)
                D_losses.update(D_loss, self.batch_size)

                # Update generator once every n_critic batches
                if batch_idx % n_critic == 0:
                    G_loss = self.gan.train_G(self.G_optimizer,
                                              self.batch_size)
                    G_losses.update(G_loss, self.batch_size)
                    g_iter += 1

                batch_time = int((datetime.datetime.now() - batch_start)
                                 .total_seconds() * 1000)
                avg_time_per_batch.update(batch_time)

                # Report model statistics
                if self._should_report(batch_idx):
                    G_all_losses.append(G_losses.avg)
                    D_all_losses.append(D_losses.avg)
                    utils.show_learning_stats(batch_idx, self.num_batches,
                                              G_losses.avg, D_losses.avg,
                                              avg_time_per_batch.avg)
                    for meter in (G_losses, D_losses, avg_time_per_batch):
                        meter.reset()
                    self.eval(100, epoch=epoch, while_training=True)

                # Save stats
                if batch_idx % self.save_stats_interval == 0 and batch_idx:
                    stats = dict(G_loss=G_all_losses, D_loss=D_all_losses)
                    self.save_stats(stats)

            # Save model
            utils.clear_line()
            print('Elapsed time for epoch: {}'.format(
                utils.time_elapsed_since(start_epoch)))
            self.gan.save_model(self.ckpt_path, epoch)
            self.eval(100, epoch=epoch, while_training=True)

        # Print elapsed time
        elapsed = utils.time_elapsed_since(start)
        print('Training done! Total elapsed time: {}\n'.format(elapsed))

        return G_loss, D_loss


class CelebA(object):
    """Implement DCGAN for CelebA dataset.

    Variant whose discriminator/generator updates receive a per-epoch
    ``freq_weight`` and which dumps generated samples after every epoch.
    """

    def __init__(self, train_params, ckpt_params, gan_params):
        """Store the configuration, create the checkpoint dir and compile.

        Args:
            train_params: dict providing 'root_dir', 'batch_size',
                'train_len', 'learning_rate', 'momentum' (passed to Adam as
                ``betas``), 'optim' ('adam' or 'rmsprop') and 'use_cuda'.
            ckpt_params: dict providing 'batch_report_interval', 'ckpt_path'
                and 'save_stats_interval'.
            gan_params: dict providing 'gan_type', 'latent_dim', 'n_critic'.

        Raises:
            NotImplementedError: if 'optim' is not a supported optimizer.
        """
        # Training parameters
        self.root_dir = train_params['root_dir']
        self.batch_size = train_params['batch_size']
        self.train_len = train_params['train_len']
        self.learning_rate = train_params['learning_rate']
        self.momentum = train_params['momentum']
        self.optim = train_params['optim']
        self.use_cuda = train_params['use_cuda']

        # Checkpoint parameters (when, where)
        self.batch_report_interval = ckpt_params['batch_report_interval']
        self.ckpt_path = ckpt_params['ckpt_path']
        self.save_stats_interval = ckpt_params['save_stats_interval']

        # Create checkpoint directory (including missing parents) if needed
        if not os.path.isdir(self.ckpt_path):
            print(self.ckpt_path)
            os.makedirs(self.ckpt_path)

        # GAN parameters
        self.gan_type = gan_params['gan_type']
        self.latent_dim = gan_params['latent_dim']
        self.n_critic = gan_params['n_critic']

        # Number of full batches per epoch (train() drops a trailing partial
        # batch; assumes train_len is the size of the data it iterates).
        self.num_batches = self.train_len // self.batch_size

        self.compile()
        # Frequency weight handed to train_D/train_G; set per epoch in train()
        self.freq_weight = 0

    def compile(self):
        """Build the GAN and its generator/discriminator optimizers.

        Raises:
            NotImplementedError: if ``self.optim`` is not 'adam'/'rmsprop'.
        """
        # Create new GAN
        self.gan = DCGAN(self.gan_type, self.latent_dim, self.batch_size,
                         self.use_cuda)

        # Move to the GPU *before* building the optimizers, as recommended
        # by torch.optim, so they reference the final parameters.
        if torch.cuda.is_available() and self.use_cuda:
            self.gan = self.gan.cuda()

        # Set optimizers for generator and discriminator
        if self.optim == 'adam':
            self.G_optimizer = optim.Adam(self.gan.G.parameters(),
                                          lr=self.learning_rate,
                                          betas=self.momentum)
            self.D_optimizer = optim.Adam(self.gan.D.parameters(),
                                          lr=self.learning_rate,
                                          betas=self.momentum)
        elif self.optim == 'rmsprop':
            self.G_optimizer = optim.RMSprop(self.gan.G.parameters(),
                                             lr=self.learning_rate)
            self.D_optimizer = optim.RMSprop(self.gan.D.parameters(),
                                             lr=self.learning_rate)
        else:
            raise NotImplementedError(
                'Unsupported optimizer: {}'.format(self.optim))

    def save_stats(self, stats):
        """Pickle ``stats`` to '<ckpt_path>/<gan_type>-stats.pkl'."""
        fname_pkl = '{}/{}-stats.pkl'.format(self.ckpt_path, self.gan_type)
        print('Saving model statistics to: {}'.format(fname_pkl))
        with open(fname_pkl, 'wb') as fp:
            pickle.dump(stats, fp)

    def _should_report(self, batch_idx):
        """Return True if learning stats should be reported after batch_idx.

        Reports every ``batch_report_interval`` batches (skipping batch 0).
        When the interval equals the number of batches per epoch, that rule
        can never fire, so the report happens once, on the last batch.
        """
        if self.batch_report_interval == self.num_batches:
            return batch_idx + 1 == self.num_batches
        return batch_idx > 0 and batch_idx % self.batch_report_interval == 0

    def test(self, epoch, n=10000):
        """Reload the generator checkpoint of ``epoch`` and dump ``n`` samples.

        Loads '<ckpt_path>/<gan_type>-gen-epoch-<epoch+1>.pt' (presumably the
        file written by ``save_model`` for this epoch -- TODO confirm) and
        writes images to '<ckpt_path>/testing/<epoch+1>/<i>_test.png'.

        Args:
            epoch: zero-based epoch index.
            n: number of images to generate (default 10000).
        """
        fname_gen_pt = '{}/{}-gen-epoch-{}.pt'.format(self.ckpt_path,
                                                      self.gan_type, epoch + 1)
        self.gan.load_model(fname_gen_pt)

        directory = self.ckpt_path + "/testing/" + str(epoch + 1)
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Evaluation mode
        self.gan.G.eval()
        for i in range(n):
            img = self.gan.generate_img()
            img = utils.unnormalize(img.squeeze())
            fname_in = '{}/{:d}_test.png'.format(directory, i)
            torchvision.utils.save_image(img, fname_in)

    def train(self, nb_epochs, data_loader):
        """Train the GAN for ``nb_epochs`` passes over ``data_loader``.

        After each epoch the model is saved and ``test`` dumps samples.

        Args:
            nb_epochs: number of epochs.
            data_loader: iterable of ``(images, labels)`` batches; labels are
                ignored and a trailing batch smaller than ``batch_size`` is
                discarded.

        Returns:
            ``(G_loss, D_loss)`` from the last generator / discriminator
            update, or ``None`` for a loss whose update never ran.
        """
        # Initialize tracked quantities and prepare everything
        G_all_losses, D_all_losses = [], []
        G_loss, D_loss = None, None
        utils.format_hdr(self.gan, self.root_dir, self.train_len)
        start = datetime.datetime.now()

        g_iter = 0

        # Train
        for epoch in range(nb_epochs):
            print('EPOCH {:d} / {:d}'.format(epoch + 1, nb_epochs))
            G_losses, D_losses = utils.AvgMeter(), utils.AvgMeter()
            start_epoch = datetime.datetime.now()
            avg_time_per_batch = utils.AvgMeter()

            # Weight ramps linearly from 1/nb_epochs to 1 across epochs; it
            # is constant within an epoch, so compute it once (float() keeps
            # true division even under Python 2).
            self.freq_weight = float(epoch + 1) / nb_epochs

            # Mini-batch SGD
            for batch_idx, (x, _) in enumerate(data_loader):

                # Critic update ratio: WGAN uses 20 critic steps early on
                # (first 50 generator iterations) and every 500th one.
                if self.gan_type == 'wgan' and (
                        g_iter < 50 or (g_iter + 1) % 500 == 0):
                    n_critic = 20
                else:
                    n_critic = self.n_critic

                # Training mode
                self.gan.G.train()

                # Discard last examples to simplify code
                if x.size(0) != self.batch_size:
                    break
                batch_start = datetime.datetime.now()

                # Print progress bar
                utils.progress_bar(batch_idx, self.batch_report_interval,
                                   G_losses.avg, D_losses.avg)

                x = Variable(x)
                if torch.cuda.is_available() and self.use_cuda:
                    x = x.cuda()

                # Update discriminator
                D_loss, _ = self.gan.train_D(x, self.freq_weight,
                                             self.D_optimizer,
                                             self.batch_size)
                D_losses.update(D_loss, self.batch_size)

                # Update generator once every n_critic batches
                if batch_idx % n_critic == 0:
                    G_loss = self.gan.train_G(self.freq_weight,
                                              self.G_optimizer,
                                              self.batch_size)
                    G_losses.update(G_loss, self.batch_size)
                    g_iter += 1

                batch_time = int((datetime.datetime.now() - batch_start)
                                 .total_seconds() * 1000)
                avg_time_per_batch.update(batch_time)

                # Report model statistics
                if self._should_report(batch_idx):
                    G_all_losses.append(G_losses.avg)
                    D_all_losses.append(D_losses.avg)
                    utils.show_learning_stats(batch_idx, self.num_batches,
                                              G_losses.avg, D_losses.avg,
                                              avg_time_per_batch.avg)
                    for meter in (G_losses, D_losses, avg_time_per_batch):
                        meter.reset()

                # Save stats
                if batch_idx % self.save_stats_interval == 0 and batch_idx:
                    stats = dict(G_loss=G_all_losses, D_loss=D_all_losses)
                    self.save_stats(stats)

            # Save model
            utils.clear_line()
            print('Elapsed time for epoch: {}'.format(
                utils.time_elapsed_since(start_epoch)))
            self.gan.save_model(self.ckpt_path, epoch, False)
            # Generate samples from this instance (was: undefined `model`)
            self.test(epoch)

        # Print elapsed time
        elapsed = utils.time_elapsed_since(start)
        print('Training done! Total elapsed time: {}\n'.format(elapsed))

        return G_loss, D_loss