Esempio n. 1
0
class Trainer(object):
    def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs, pre_trained_disc_B,
                 pre_trained_gen_B):
        """Build the networks, dataset, data loader and optimizers.

        Args:
            type: Training mode later read by ``train``: 'gan', 'stackgan'
                or 'pretrain_caption'. (Shadows the builtin; the name is
                kept because callers may pass it by keyword.)
            dataset: 'birds' or 'flowers'. Any other value prints a message
                and terminates the process.
            split: Forwarded unchanged to ``Text2ImageDataset``.
            lr: Learning rate of the four image-GAN Adam optimizers.
            diter: Stored as ``self.DITER``.
            vis_screen: Accepted but not used by this method.
            save_path: Stored as ``self.save_path``.
            l1_coef: Stored as ``self.l1_coef``.
            l2_coef: Stored as ``self.l2_coef``.
            pre_trained_gen: Optional state-dict path for the 'gan'
                generator; when falsy ``Utils.weights_init`` is applied.
            pre_trained_disc: Same, for the 'gan' discriminator.
            batch_size: Batch size of the data loader.
            num_workers: Worker count of the data loader.
            epochs: Stored as ``self.num_epochs``.
            pre_trained_disc_B: Same as ``pre_trained_disc``, for the
                'stage2_gan' discriminator.
            pre_trained_gen_B: Same as ``pre_trained_gen``, for the
                'stage2_gan' generator.

        Raises:
            SystemExit: If ``dataset`` is neither 'birds' nor 'flowers'.
        """
        with open('config.yaml', 'r') as f:
            # yaml.load() without an explicit Loader is deprecated since
            # PyYAML 5.1 and raises TypeError from 6.0 on; safe_load is
            # available in every version and is enough for a plain config.
            config = yaml.safe_load(f)

        def on_device(module):
            # Move the module to the GPU only when CUDA is enabled.
            return module.cuda() if is_cuda else module

        def load_or_init(net, checkpoint):
            # Restore weights from ``checkpoint`` if one was given,
            # otherwise fall back to the project's weight initialiser.
            if checkpoint:
                net.load_state_dict(torch.load(checkpoint))
            else:
                net.apply(Utils.weights_init)

        # forward gan (construction order kept: it fixes the RNG draw order)
        self.generator = torch.nn.DataParallel(
            on_device(gan_factory.generator_factory('gan')))
        self.discriminator = torch.nn.DataParallel(
            on_device(gan_factory.discriminator_factory('gan')))
        self.generator2 = torch.nn.DataParallel(
            on_device(gan_factory.generator_factory('stage2_gan')))
        self.discriminator2 = torch.nn.DataParallel(
            on_device(gan_factory.discriminator_factory('stage2_gan')))

        load_or_init(self.discriminator, pre_trained_disc)
        load_or_init(self.generator, pre_trained_gen)
        load_or_init(self.discriminator2, pre_trained_disc_B)
        load_or_init(self.generator2, pre_trained_gen_B)

        if dataset not in ('birds', 'flowers'):
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            # ``exit()`` is injected by the ``site`` module for interactive
            # use and is not guaranteed to exist; raise SystemExit directly
            # and report the failure through a non-zero exit status.
            raise SystemExit(1)

        # NOTE: pickle.load can execute arbitrary code from the file; only
        # point this at vocabulary files you trust.
        with open('./data/%s_vocab.pkl' % dataset, 'rb') as f:
            self.vocab = pickle.load(f)
        self.dataset = Text2ImageDataset(config['%s_dataset_path' % dataset],
                                         dataset_type=dataset,
                                         vocab=self.vocab,
                                         split=split)

        self.noise_dim = 100
        self.batch_size = batch_size
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.DITER = diter
        self.num_workers = num_workers

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers,
                                      collate_fn=collate_fn)

        adam_betas = (self.beta1, 0.999)
        self.optimD = torch.optim.Adam(self.discriminator.parameters(),
                                       lr=self.lr,
                                       betas=adam_betas)
        self.optimG = torch.optim.Adam(self.generator.parameters(),
                                       lr=self.lr,
                                       betas=adam_betas)

        self.optimD2 = torch.optim.Adam(self.discriminator2.parameters(),
                                        lr=self.lr,
                                        betas=adam_betas)
        self.optimG2 = torch.optim.Adam(self.generator2.parameters(),
                                        lr=self.lr,
                                        betas=adam_betas)

        self.checkpoints_path = './checkpoints/'
        self.save_path = save_path
        self.type = type

        # TODO: put these as runtime.py params
        self.embed_size = 256
        self.hidden_size = 512
        self.num_layers = 1

        self.gen_pretrain_num_epochs = 100
        self.disc_pretrain_num_epochs = 20

        self.figure_path = './figures/'

        vocab_size = len(self.vocab)
        self.caption_generator = on_device(
            CaptionGenerator(self.embed_size, self.hidden_size, vocab_size,
                             self.num_layers))
        self.caption_discriminator = on_device(
            CaptionDiscriminator(self.embed_size, self.hidden_size,
                                 vocab_size, self.num_layers))

        # Pretrained caption networks are optional: they are loaded only if
        # these exact files exist.
        pretrained_caption_gen = './checkpoints/pretrained-generator-20.pkl'
        pretrained_caption_disc = './checkpoints/pretrained-discriminator-5.pkl'

        if os.path.exists(pretrained_caption_gen):
            print('loaded pretrained caption generator')
            self.caption_generator.load_state_dict(
                torch.load(pretrained_caption_gen))

        if os.path.exists(pretrained_caption_disc):
            print('loaded pretrained caption discriminator')
            self.caption_discriminator.load_state_dict(
                torch.load(pretrained_caption_disc))

        # Caption optimizers use Adam's default hyper-parameters.
        self.optim_captionG = torch.optim.Adam(
            list(self.caption_generator.parameters()))
        self.optim_captionD = torch.optim.Adam(
            list(self.caption_discriminator.parameters()))

    def train(self, cls=False, interp=False):

        if self.type == 'gan':
            self._train_gan(cls, interp)
        elif self.type == 'stackgan':
            self._train_stack_gan(cls, interp)
        elif self.type == 'pretrain_caption':
            self._pretrain_caption()

    def _pretrain_caption(self):
        """Pretrain a freshly built caption generator and discriminator.

        New ``CaptionGenerator`` / ``CaptionDiscriminator`` instances are
        created here (the ones stored on ``self`` are not touched). The
        generator is trained with cross-entropy against the packed captions
        for ``gen_pretrain_num_epochs`` epochs; the discriminator is trained
        on real vs. wrong captions for ``disc_pretrain_num_epochs`` epochs.
        Both state dicts are saved under ``self.checkpoints_path`` and the
        loss curves under ``self.figure_path``.
        """
        # Make sure both output directories exist.
        for directory in (self.checkpoints_path, self.figure_path):
            if not os.path.exists(directory):
                os.makedirs(directory)

        # Build the two caption networks (generator first, then
        # discriminator).
        caption_gen = CaptionGenerator(self.embed_size, self.hidden_size,
                                       len(self.vocab), self.num_layers)
        caption_disc = CaptionDiscriminator(self.embed_size,
                                            self.hidden_size,
                                            len(self.vocab), self.num_layers)

        if torch.cuda.is_available():
            caption_gen.cuda()
            caption_disc.cuda()

        # Losses and optimizers (Adam with default hyper-parameters).
        mle_criterion = nn.CrossEntropyLoss()
        gen_optimizer = torch.optim.Adam(list(caption_gen.parameters()))
        disc_optimizer = torch.optim.Adam(list(caption_disc.parameters()))

        gen_epochs = int(self.gen_pretrain_num_epochs)
        disc_epochs = int(self.disc_pretrain_num_epochs)

        disc_history = []
        gen_history = []
        # Loop for the longer of the two schedules; each network simply
        # stops being updated once its own epoch budget is used up.
        for epoch in tqdm(range(max(gen_epochs, disc_epochs))):
            for batch in self.data_loader:
                images, captions, lengths, wrong_captions, wrong_lengths = (
                    batch[key]
                    for key in ('right_images128', 'captions', 'lengths',
                                'wrong_captions', 'wrong_lengths'))

                images = to_var(images, volatile=True)
                captions = to_var(captions)
                wrong_captions = to_var(wrong_captions)

                packed = pack_padded_sequence(captions,
                                              lengths,
                                              batch_first=True)
                targets = packed[0]

                if epoch < gen_epochs:
                    caption_gen.zero_grad()
                    outputs, _ = caption_gen(images, captions, lengths)
                    loss_gen = mle_criterion(outputs, targets)
                    # NOTE: ``.numpy()[0]`` is the pre-0.4 PyTorch way of
                    # reading a scalar loss.
                    gen_history.append(loss_gen.cpu().data.numpy()[0])
                    loss_gen.backward()
                    gen_optimizer.step()

                if epoch < disc_epochs:
                    caption_disc.zero_grad()
                    rewards_real = caption_disc(images, captions, lengths)
                    rewards_wrong = caption_disc(images, wrong_captions,
                                                 wrong_lengths)
                    real_loss = -torch.mean(torch.log(rewards_real))
                    wrong_loss = -torch.mean(
                        torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
                    # No "fake" term: sampled captions are not used during
                    # pretraining.
                    loss_disc = real_loss + wrong_loss

                    disc_history.append(loss_disc.cpu().data.numpy()[0])
                    loss_disc.backward()
                    disc_optimizer.step()

        # Save pretrained models (discriminator first, then generator).
        disc_file = os.path.join(
            self.checkpoints_path,
            'pretrained-discriminator-%d.pkl' % disc_epochs)
        torch.save(caption_disc.state_dict(), disc_file)

        gen_file = os.path.join(self.checkpoints_path,
                                'pretrained-generator-%d.pkl' % gen_epochs)
        torch.save(caption_gen.state_dict(), gen_file)

        # Plot pretraining figures.
        plt.plot(disc_history, label='pretraining_caption_disc_loss')
        plt.savefig(self.figure_path + 'pretraining_caption_disc_losses.png')
        plt.clf()

        plt.plot(gen_history, label='pretraining_gen_loss')
        plt.savefig(self.figure_path + 'pretraining_gen_losses.png')
        plt.clf()

    def _train_gan(self, cls, interp):
        """Train ``self.generator`` / ``self.discriminator`` adversarially.

        Each batch performs one discriminator update followed by one
        generator update. After every epoch the loss histories are pickled
        to 'gen.pkl' / 'disc.pkl' and plotted to 'gen_vs_disc_.png'; a
        checkpoint is written through ``Utils.save_checkpoint`` whenever
        ``epoch % 50 == 0``.

        Args:
            cls: When truthy, the discriminator is also trained on
                (wrong image, right embedding) pairs labelled as fake, and
                the wrong/fake terms are each weighted by 0.5.
            interp: When truthy, the generator loss gets an extra term
                (weight 0.2) computed on embeddings averaged between the
                first and second half of the batch.
        """
        criterion = nn.BCELoss()
        noise_dim = 100

        def on_device(tensor):
            # Wrap in a Variable and move to the GPU when CUDA is enabled.
            var = Variable(tensor)
            return var.cuda() if is_cuda else var

        def sample_noise(count):
            # Gaussian noise in the (count, noise_dim, 1, 1) layout that
            # every generator call in this method uses.
            noise = on_device(torch.randn(count, noise_dim))
            return noise.view(count, noise_dim, 1, 1)

        gen_losses = []
        disc_losses = []
        for epoch in tqdm(range(self.num_epochs)):
            for sample in tqdm(self.data_loader):
                right_images = on_device(sample['right_images'].float())
                right_embed = on_device(sample['right_embed'].float())
                wrong_images = on_device(sample['wrong_images'].float())

                batch_size = right_images.size(0)
                real_labels = torch.ones(batch_size)
                fake_labels = torch.zeros(batch_size)

                # ======== One sided label smoothing ==========
                # Helps preventing the discriminator from overpowering the
                # generator adding penalty when the discriminator is too
                # confident
                # =============================================
                smoothed_real_labels = torch.FloatTensor(
                    Utils.smooth_label(real_labels.numpy(), -0.1))

                real_labels = on_device(real_labels)
                smoothed_real_labels = on_device(smoothed_real_labels)
                fake_labels = on_device(fake_labels)

                # ---------------- Train the discriminator ----------------
                self.discriminator.zero_grad()
                outputs, _ = self.discriminator(right_images, right_embed)
                real_loss = criterion(outputs, smoothed_real_labels)

                if cls:
                    outputs, _ = self.discriminator(wrong_images, right_embed)
                    wrong_loss = criterion(outputs, fake_labels)

                # Detached: the discriminator update must not spend time
                # back-propagating into the generator, whose gradients are
                # zeroed before its own update anyway.
                fake_images = self.generator(
                    right_embed, sample_noise(batch_size)).detach()
                outputs, _ = self.discriminator(fake_images, right_embed)
                fake_loss = criterion(outputs, fake_labels)

                if cls:
                    d_loss = real_loss + 0.5 * wrong_loss + 0.5 * fake_loss
                else:
                    d_loss = real_loss + fake_loss

                d_loss.backward()
                self.optimD.step()

                # ------------------ Train the generator ------------------
                self.generator.zero_grad()
                fake_images = self.generator(right_embed,
                                             sample_noise(batch_size))
                outputs, _ = self.discriminator(fake_images, right_embed)

                # Only the adversarial cross-entropy term is active. The
                # feature-matching (l2_coef) and pixel L1 (l1_coef) terms
                # are disabled, so the extra discriminator pass on the real
                # images that only fed the feature-matching term is not run.
                g_loss = criterion(outputs, real_labels)

                if interp:
                    # GAN-INT loss: average the embeddings of the two batch
                    # halves. With an odd batch the last sample is left out
                    # so that both halves have the same length.
                    half = int(right_embed.size(0)) // 2
                    if half > 0:
                        first_part = right_embed[:half, :]
                        second_part = right_embed[half:2 * half, :]
                        interp_embed = (first_part + second_part) * 0.5

                        # Noise and labels are sized to the interpolated
                        # batch on both CPU and GPU, and the noise gets the
                        # same 4-D layout as in the calls above.
                        interp_noise = sample_noise(half)
                        interp_real_labels = on_device(torch.ones(half))

                        interp_images = self.generator(interp_embed,
                                                       interp_noise)
                        outputs, _ = self.discriminator(
                            interp_images, interp_embed)
                        g_int_loss = criterion(outputs, interp_real_labels)
                        g_loss = g_loss + 0.2 * g_int_loss

                g_loss.backward()
                self.optimG.step()

                # NOTE: ``.data[0]`` is the pre-0.4 PyTorch way of reading a
                # scalar loss.
                gen_losses.append(g_loss.data[0])
                disc_losses.append(d_loss.data[0])

            with open('gen.pkl', 'wb') as f_gen, open('disc.pkl',
                                                      'wb') as f_disc:
                pickle.dump(gen_losses, f_gen)
                pickle.dump(disc_losses, f_disc)

            x = list(range(len(gen_losses)))
            plt.plot(x, gen_losses, 'g-', label='gen loss')
            plt.plot(x, disc_losses, 'b-', label='disc loss')
            plt.legend()
            plt.savefig('gen_vs_disc_.png')
            plt.clf()

            if epoch % 50 == 0:
                Utils.save_checkpoint(self.discriminator, self.generator,
                                      self.checkpoints_path, self.save_path,
                                      epoch)

    def _train_stack_gan(self, cls, interp):
        """Train both GAN stages together with the caption cycle loss.

        Every batch runs, in order:
          1. a stage-1 discriminator update (``self.discriminator``),
          2. a stage-1 generator update (``self.generator``),
          3. a stage-2 discriminator update (``self.discriminator2``) on the
             'right_images128' / 'wrong_images128' images,
          4. one joint update of ``self.generator2`` and
             ``self.caption_generator`` on the adversarial loss plus the
             caption cross-entropy ("cycle A") loss scaled by ``lambda_a``.

        After each epoch the stage-2 loss histories are pickled to
        'gen.pkl' / 'disc.pkl'; every 10th epoch all networks are
        checkpointed. Three loss figures are written to
        ``self.figure_path`` when training finishes.

        Args:
            cls: When truthy, both discriminators are also trained on
                (wrong image, right embedding) pairs labelled as fake.
            interp: When truthy, the stage-1 generator loss gets an extra
                term (weight 0.2) computed on embeddings averaged between
                the first and second half of the batch.
        """
        criterion = nn.BCELoss()
        mle_criterion = nn.CrossEntropyLoss()
        noise_dim = 100

        # cycle gan params
        lambda_a = 2

        def on_device(tensor):
            # Wrap in a Variable and move to the GPU when CUDA is enabled.
            var = Variable(tensor)
            return var.cuda() if is_cuda else var

        def sample_noise(count):
            # Gaussian noise in the (count, noise_dim, 1, 1) layout that
            # every generator call in this method uses.
            noise = on_device(torch.randn(count, noise_dim))
            return noise.view(count, noise_dim, 1, 1)

        # torch.save / plt.savefig below do not create missing directories.
        for directory in (self.checkpoints_path, self.figure_path):
            if not os.path.exists(directory):
                os.makedirs(directory)

        gen_losses = []
        disc_losses = []
        cycle_a_losses = []
        for epoch in tqdm(range(self.num_epochs)):
            for sample in tqdm(self.data_loader):
                right_images = on_device(sample['right_images'].float())
                right_embed = on_device(sample['right_embed'].float())
                wrong_images = on_device(sample['wrong_images'].float())
                right_images128 = on_device(
                    sample['right_images128'].float())
                wrong_images128 = on_device(
                    sample['wrong_images128'].float())
                right_captions = on_device(sample['captions'].long())
                right_lengths = sample['lengths']

                batch_size = right_images.size(0)
                real_labels = torch.ones(batch_size)
                fake_labels = torch.zeros(batch_size)

                # ======== One sided label smoothing ==========
                # Helps preventing the discriminator from overpowering the
                # generator adding penalty when the discriminator is too
                # confident
                # =============================================
                smoothed_real_labels = torch.FloatTensor(
                    Utils.smooth_label(real_labels.numpy(), -0.1))

                real_labels = on_device(real_labels)
                smoothed_real_labels = on_device(smoothed_real_labels)
                fake_labels = on_device(fake_labels)

                # ---------------- Training D stage 1 ---------------------
                self.discriminator.zero_grad()
                outputs, _ = self.discriminator(right_images, right_embed)
                real_loss = criterion(outputs, smoothed_real_labels)

                if cls:
                    outputs, _ = self.discriminator(wrong_images, right_embed)
                    wrong_loss = criterion(outputs, fake_labels)

                # Detached: the discriminator update must not spend time
                # back-propagating into the generator, whose gradients are
                # zeroed before its own update anyway.
                fake_images = self.generator(
                    right_embed, sample_noise(batch_size)).detach()
                outputs, _ = self.discriminator(fake_images, right_embed)
                fake_loss = criterion(outputs, fake_labels)

                if cls:
                    d_loss = real_loss + 0.5 * wrong_loss + 0.5 * fake_loss
                else:
                    d_loss = real_loss + fake_loss

                d_loss.backward()
                self.optimD.step()

                # ---------------- Training G stage 1 ---------------------
                self.generator.zero_grad()
                self.discriminator.zero_grad()
                fake_images = self.generator(right_embed,
                                             sample_noise(batch_size))
                outputs, _ = self.discriminator(fake_images, right_embed)
                g_loss = criterion(outputs, real_labels)

                if interp:
                    # GAN-INT loss: average the embeddings of the two batch
                    # halves. With an odd batch the last sample is left out
                    # so that both halves have the same length.
                    half = int(right_embed.size(0)) // 2
                    if half > 0:
                        first_part = right_embed[:half, :]
                        second_part = right_embed[half:2 * half, :]
                        interp_embed = (first_part + second_part) * 0.5

                        # Noise and labels are sized to the interpolated
                        # batch on both CPU and GPU.
                        interp_noise = sample_noise(half)
                        interp_real_labels = on_device(torch.ones(half))

                        interp_images = self.generator(interp_embed,
                                                       interp_noise)
                        outputs, _ = self.discriminator(
                            interp_images, interp_embed)
                        g_int_loss = criterion(outputs, interp_real_labels)
                        g_loss = g_loss + 0.2 * g_int_loss

                g_loss.backward()
                self.optimG.step()

                # ---------------- Training D stage 2 ---------------------
                self.discriminator2.zero_grad()
                outputs = self.discriminator2(right_images128, right_embed)
                real_loss = criterion(outputs, smoothed_real_labels)

                if cls:
                    outputs = self.discriminator2(wrong_images128, right_embed)
                    wrong_loss = criterion(outputs, fake_labels)

                fake_images_v1 = self.generator(
                    right_embed, sample_noise(batch_size)).detach()
                fake_images = self.generator2(fake_images_v1,
                                              right_embed).detach()
                outputs = self.discriminator2(fake_images, right_embed)
                fake_loss = criterion(outputs, fake_labels)

                if cls:
                    d_loss2 = real_loss + 0.5 * wrong_loss + 0.5 * fake_loss
                else:
                    d_loss2 = real_loss + fake_loss

                d_loss2.backward()
                self.optimD2.step()

                # -------- Training G stage 2 + caption generator ---------
                self.generator2.zero_grad()
                self.discriminator2.zero_grad()
                self.caption_generator.zero_grad()

                # Stage 1 is frozen here: only generator2 receives gradient.
                fake_images_v1 = self.generator(
                    right_embed, sample_noise(batch_size)).detach()
                fake_images = self.generator2(fake_images_v1, right_embed)
                outputs = self.discriminator2(fake_images, right_embed)
                g_loss2 = criterion(outputs, real_labels)

                # Caption the generated images with the caption GAN
                # (inverse GAN) and compare against the real captions.
                sampled_captions, _ = self.caption_generator.forward(
                    fake_images, right_captions, right_lengths)
                targets = pack_padded_sequence(right_captions,
                                               right_lengths,
                                               batch_first=True)[0]
                loss_cycle_A = mle_criterion(sampled_captions,
                                             targets) * lambda_a

                # One backward pass over the summed objective, then one
                # step per optimizer. Back-propagating the two losses
                # separately would send the second backward through a graph
                # the first one already consumed and would step optimG2
                # twice on accumulated gradients.
                (g_loss2 + loss_cycle_A).backward()
                self.optimG2.step()
                self.optim_captionG.step()

                # NOTE: ``.data[0]`` is the pre-0.4 PyTorch way of reading a
                # scalar loss.
                gen_losses.append(g_loss2.data[0])
                disc_losses.append(d_loss2.data[0])
                cycle_a_losses.append(loss_cycle_A.data[0])

            with open('gen.pkl', 'wb') as f_gen, open('disc.pkl',
                                                      'wb') as f_disc:
                pickle.dump(gen_losses, f_gen)
                pickle.dump(disc_losses, f_disc)

            if (epoch + 1) % 10 == 0:
                Utils.save_checkpoint(self.discriminator, self.generator,
                                      self.checkpoints_path, self.save_path,
                                      epoch + 1)
                Utils.save_checkpoint(self.discriminator2, self.generator2,
                                      self.checkpoints_path, self.save_path,
                                      epoch + 1, False, 2)
                torch.save(
                    self.caption_discriminator.state_dict(),
                    os.path.join(self.checkpoints_path,
                                 'cycle_caption_disc-%d.pkl' % (epoch + 1)))
                torch.save(
                    self.caption_generator.state_dict(),
                    os.path.join(self.checkpoints_path,
                                 'cycle_caption_gen-%d.pkl' % (epoch + 1)))

        # Plot training figures. NOTE: the labels and file names say
        # "stage 1", but the recorded values are the stage-2 losses
        # (d_loss2 / g_loss2); the names are kept so existing output paths
        # do not change.
        plt.plot(disc_losses, label='stage 1 disc losses')
        plt.savefig(self.figure_path + 'stage_1_disc_losses.png')
        plt.clf()

        plt.plot(gen_losses, label='stage_1_gen_loss')
        plt.savefig(self.figure_path + 'stage_1_gen_losses.png')
        plt.clf()

        # Plot the cycle losses themselves (not the discriminator losses).
        plt.plot(cycle_a_losses, label='cycle_a_losses')
        plt.savefig(self.figure_path + 'cycle_a_losses.png')
        plt.clf()

    def predict(self, gan_type='gan'):
        """Generate images for the first batch of the data loader.

        The RNG is seeded with 7, one batch is drawn, images are produced
        by ``self.generator`` (and refined by ``self.generator2`` when
        ``gan_type`` is 'stackgan'), and each image is saved as
        'results/<save_path>/<caption>.jpg'. The caption has '/' removed
        and is cut to 200 characters; it is also printed.
        """
        torch.manual_seed(7)
        for batch in self.data_loader:
            real = batch['right_images']
            embed = batch['right_embed']
            captions = batch['txt']

            out_dir = 'results/{0}'.format(self.save_path)
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)

            real = Variable(real.float())
            embed = Variable(embed.float())
            if is_cuda:
                real = real.cuda()
                embed = embed.cuda()

            # Sample the generator input noise as (batch, 100, 1, 1).
            z = Variable(torch.randn(real.size(0), 100))
            if is_cuda:
                z = z.cuda()
            z = z.view(z.size(0), 100, 1, 1)

            generated = self.generator(embed, z)
            if gan_type == 'stackgan':
                generated = self.generator2(generated, embed)

            for picture, caption in zip(generated, captions):
                # Rescale in place (x * 127.5 + 127.5), convert to bytes and
                # move channels last before handing the array to PIL.
                pixels = picture.data.mul_(127.5).add_(127.5).byte()
                array = pixels.permute(1, 2, 0).cpu().numpy()
                target = 'results/{0}/{1}.jpg'.format(
                    self.save_path,
                    caption.replace("/", "")[:200])
                Image.fromarray(array).save(target)
                print(caption)

            # Only the first batch is ever processed.
            break
Esempio n. 2
0
    def __init__(self, type, dataset, split, lr, diter, vis_screen, save_path,
                 l1_coef, l2_coef, pre_trained_gen, pre_trained_disc,
                 batch_size, num_workers, epochs, pre_trained_disc_B,
                 pre_trained_gen_B):
        """Build the two-stage GAN, the caption models, data and optimizers.

        Reads ``config.yaml`` from the current working directory, constructs
        the stage-1 (``'gan'``) and stage-2 (``'stage2_gan'``) generator and
        discriminator wrapped in ``DataParallel``, loads or initialises their
        weights, builds the dataset/data loader and creates Adam optimizers
        for every network.

        Args:
            type: Stored on ``self.type`` (the name shadows the builtin but
                is kept for caller compatibility).
            dataset: ``'birds'`` or ``'flowers'``. Any other value prints a
                message and terminates the process with exit status 1.
            split: Forwarded to ``Text2ImageDataset``.
            lr: Learning rate of the four image-GAN Adam optimizers.
            diter: Stored on ``self.DITER``.
            vis_screen: Accepted but not used by this constructor.
            save_path: Stored on ``self.save_path``.
            l1_coef: Stored on ``self.l1_coef``.
            l2_coef: Stored on ``self.l2_coef``.
            pre_trained_gen: Checkpoint path for the stage-1 generator; when
                falsy the network is initialised with ``Utils.weights_init``.
            pre_trained_disc: Same, for the stage-1 discriminator.
            batch_size: Batch size of the data loader.
            num_workers: Worker count of the data loader.
            epochs: Stored on ``self.num_epochs``.
            pre_trained_disc_B: Same as ``pre_trained_disc`` for stage 2.
            pre_trained_gen_B: Same as ``pre_trained_gen`` for stage 2.

        Raises:
            SystemExit: If ``dataset`` is neither ``'birds'`` nor
                ``'flowers'``.
        """
        with open('config.yaml', 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is rejected
            # by PyYAML >= 6 and could construct arbitrary objects before.
            config = yaml.safe_load(f)

        def _parallel(module):
            # Move to the GPU first (when enabled), then wrap.
            return torch.nn.DataParallel(module.cuda() if is_cuda else module)

        # forward gan (construction order kept: G1, D1, G2, D2)
        self.generator = _parallel(gan_factory.generator_factory('gan'))
        self.discriminator = _parallel(
            gan_factory.discriminator_factory('gan'))
        self.generator2 = _parallel(
            gan_factory.generator_factory('stage2_gan'))
        self.discriminator2 = _parallel(
            gan_factory.discriminator_factory('stage2_gan'))

        # Restore each network from its checkpoint, or fall back to the
        # project's weight initialiser when no checkpoint path was given.
        for network, checkpoint in ((self.discriminator, pre_trained_disc),
                                    (self.generator, pre_trained_gen),
                                    (self.discriminator2, pre_trained_disc_B),
                                    (self.generator2, pre_trained_gen_B)):
            if checkpoint:
                network.load_state_dict(torch.load(checkpoint))
            else:
                network.apply(Utils.weights_init)

        if dataset not in ('birds', 'flowers'):
            print(
                'Dataset not supported, please select either birds or flowers.'
            )
            # Non-zero status so calling scripts can detect the failure;
            # SystemExit is raised directly because the bare ``exit()``
            # helper is injected by ``site`` and reports success (status 0).
            raise SystemExit(1)

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # vocabulary files generated by this project.
        with open('./data/{0}_vocab.pkl'.format(dataset), 'rb') as f:
            self.vocab = pickle.load(f)
        self.dataset = Text2ImageDataset(
            config['{0}_dataset_path'.format(dataset)],
            dataset_type=dataset,
            vocab=self.vocab,
            split=split)

        self.noise_dim = 100
        self.batch_size = batch_size
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs
        self.DITER = diter
        self.num_workers = num_workers

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=self.num_workers,
                                      collate_fn=collate_fn)

        def _adam(network):
            # Shared Adam settings for all four image-GAN networks.
            return torch.optim.Adam(network.parameters(),
                                    lr=self.lr,
                                    betas=(self.beta1, 0.999))

        self.optimD = _adam(self.discriminator)
        self.optimG = _adam(self.generator)
        self.optimD2 = _adam(self.discriminator2)
        self.optimG2 = _adam(self.generator2)

        self.checkpoints_path = './checkpoints/'
        self.save_path = save_path
        self.type = type

        # TODO: put these as runtime.py params
        self.embed_size = 256
        self.hidden_size = 512
        self.num_layers = 1

        self.gen_pretrain_num_epochs = 100
        self.disc_pretrain_num_epochs = 20

        self.figure_path = './figures/'

        self.caption_generator = CaptionGenerator(self.embed_size,
                                                  self.hidden_size,
                                                  len(self.vocab),
                                                  self.num_layers)
        if is_cuda:
            self.caption_generator = self.caption_generator.cuda()

        self.caption_discriminator = CaptionDiscriminator(
            self.embed_size, self.hidden_size, len(self.vocab),
            self.num_layers)
        if is_cuda:
            self.caption_discriminator = self.caption_discriminator.cuda()

        # NOTE(review): these file names are hard-coded to 20 / 5 epochs,
        # while gen_pretrain_num_epochs / disc_pretrain_num_epochs above are
        # 100 / 20 -- confirm the checkpoints on disk use these names.
        pretrained_caption_gen = './checkpoints/pretrained-generator-20.pkl'
        pretrained_caption_disc = './checkpoints/pretrained-discriminator-5.pkl'

        # Caption checkpoints are optional: missing files are skipped
        # silently and the freshly constructed models are used instead.
        if os.path.exists(pretrained_caption_gen):
            print('loaded pretrained caption generator')
            self.caption_generator.load_state_dict(
                torch.load(pretrained_caption_gen))

        if os.path.exists(pretrained_caption_disc):
            print('loaded pretrained caption discriminator')
            self.caption_discriminator.load_state_dict(
                torch.load(pretrained_caption_disc))

        # Caption optimizers use Adam's default hyper-parameters.
        self.optim_captionG = torch.optim.Adam(
            list(self.caption_generator.parameters()))
        self.optim_captionD = torch.optim.Adam(
            list(self.caption_discriminator.parameters()))
Esempio n. 3
0
    def _pretrain_caption(self):
        """Pretrain a fresh caption generator and discriminator, then save them.

        New ``CaptionGenerator`` / ``CaptionDiscriminator`` instances are
        created locally (nothing on ``self`` is updated). Training runs for
        ``max(gen_pretrain_num_epochs, disc_pretrain_num_epochs)`` epochs:

        * while ``epoch < gen_pretrain_num_epochs`` the generator is updated
          with a cross-entropy loss against the packed ground-truth captions;
        * while ``epoch < disc_pretrain_num_epochs`` the discriminator is
          updated on real captions versus mismatched ("wrong") captions.

        Afterwards both state dicts are written to ``self.checkpoints_path``
        (file names carry the respective epoch counts) and both loss curves
        are saved as PNG files under ``self.figure_path``.
        """
        # Create output directories
        for directory in (self.checkpoints_path, self.figure_path):
            if not os.path.exists(directory):
                os.makedirs(directory)

        vocab_size = len(self.vocab)

        # Build the models (Gen)
        generator = CaptionGenerator(self.embed_size, self.hidden_size,
                                     vocab_size, self.num_layers)

        # Build the models (Disc)
        discriminator = CaptionDiscriminator(self.embed_size,
                                             self.hidden_size, vocab_size,
                                             self.num_layers)

        if torch.cuda.is_available():
            generator.cuda()
            discriminator.cuda()

        # Loss and Optimizer (Gen)
        mle_criterion = nn.CrossEntropyLoss()
        optimizer_gen = torch.optim.Adam(list(generator.parameters()))

        # Loss and Optimizer (Disc)
        optimizer_disc = torch.optim.Adam(list(discriminator.parameters()))

        def _scalar(loss):
            # Losses are 1-element tensors on PyTorch 0.3 but 0-dim tensors
            # from 0.4 on, where ``.numpy()[0]`` raises IndexError.
            # Flattening first yields the scalar on both.
            return loss.cpu().data.numpy().reshape(-1)[0]

        gen_epochs = int(self.gen_pretrain_num_epochs)
        disc_epochs = int(self.disc_pretrain_num_epochs)

        disc_losses = []
        gen_losses = []
        for epoch in tqdm(range(max([gen_epochs, disc_epochs]))):
            for sample in self.data_loader:
                images = sample['right_images128']  # 64x3x128x128
                captions = sample['captions']
                lengths = sample['lengths']
                wrong_captions = sample['wrong_captions']
                wrong_lengths = sample['wrong_lengths']

                images = to_var(images, volatile=True)
                captions = to_var(captions)
                wrong_captions = to_var(wrong_captions)

                # Flattened (packed) ground-truth tokens used as MLE targets.
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]

                if epoch < gen_epochs:
                    generator.zero_grad()
                    outputs, _ = generator(images, captions, lengths)
                    loss_gen = mle_criterion(outputs, targets)
                    gen_losses.append(_scalar(loss_gen))
                    loss_gen.backward()
                    optimizer_gen.step()

                if epoch < disc_epochs:
                    discriminator.zero_grad()
                    rewards_real = discriminator(images, captions, lengths)
                    rewards_wrong = discriminator(images, wrong_captions,
                                                  wrong_lengths)
                    # NOTE(review): unlike wrong_loss, real_loss is not
                    # clamped, so a reward of exactly 0 gives an infinite
                    # loss -- confirm whether that is intended.
                    real_loss = -torch.mean(torch.log(rewards_real))
                    wrong_loss = -torch.mean(
                        torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
                    # No fake-caption term: this is pretraining only.
                    loss_disc = real_loss + wrong_loss

                    disc_losses.append(_scalar(loss_disc))
                    loss_disc.backward()
                    optimizer_disc.step()

        # Save pretrained models
        torch.save(
            discriminator.state_dict(),
            os.path.join(self.checkpoints_path,
                         'pretrained-discriminator-%d.pkl' % disc_epochs))
        torch.save(
            generator.state_dict(),
            os.path.join(self.checkpoints_path,
                         'pretrained-generator-%d.pkl' % gen_epochs))

        # Plot pretraining figures
        plt.plot(disc_losses, label='pretraining_caption_disc_loss')
        plt.savefig(self.figure_path + 'pretraining_caption_disc_losses.png')
        plt.clf()

        plt.plot(gen_losses, label='pretraining_gen_loss')
        plt.savefig(self.figure_path + 'pretraining_gen_losses.png')
        plt.clf()