예제 #1
0
# Preprocessing for the face images: convert to a tensor, then normalise each
# of the three channels with mean 0.5 / std 0.5.
transform2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# The train and test splits are concatenated into a single loader, so every
# available image is fed to the networks.
train_faceDataset = FaceDataset(data_path+'train', data_path+'train.csv', transform2)
test_faceDataset = FaceDataset(data_path+'test', data_path+'test.csv', transform2)
train_dataloader = DataLoader(ConcatDataset([train_faceDataset, test_faceDataset]), batch_size=batch_size, num_workers=1)

netG = Generator(d, latent_size)
netD = Discriminator(d)
if cuda:
    netG = netG.cuda()
    netD = netD.cuda()
#print(netG)
#summary(netG, (1, 128))
# print(netD)
# summary(netD, (3, 64, 64))
# exit()

criterion = nn.BCELoss()

# Both networks use Adam with identical hyper-parameters.
optimizerG = optim.Adam(netG.parameters(), lr=0.002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.002, betas=(0.5, 0.999))

# One fixed batch of latent vectors of shape (batch_size, latent_size, 1, 1).
# It is moved to the GPU only when the `cuda` flag is set, matching how
# netG/netD are handled above; an unconditional .cuda() here would crash on a
# CPU-only machine.
# NOTE(review): presumably reused to render comparable generator samples
# across epochs - confirm against the training loop.
fix_noise = torch.randn(batch_size, latent_size, 1, 1)
if cuda:
    fix_noise = fix_noise.cuda()

# History buffers; presumably filled by the training loop (not visible here).
lossG = []
lossD = []
Dx = []
예제 #2
0
def main(args):
    """Train the caption generator and discriminator adversarially.

    The function runs in two stages:

    1. When ``int(args.pretraining) == 1``: both networks are restored from
       ``args.pretrained_gen_path`` / ``args.pretrained_disc_path`` and then
       pre-trained. The generator is trained with cross-entropy against the
       ground-truth captions; the discriminator is trained with a "real" loss
       and a "wrong" loss (the wrong captions come from the data loader).
       Epochs 0-4 of this stage are skipped. Otherwise the networks are only
       restored from the two checkpoint paths.
    2. Adversarial training for ``args.num_epochs`` epochs. The generator
       loss is the dot product of its outputs with the negated discriminator
       rewards; the discriminator is trained with real, fake (sampled) and
       wrong losses.

    Checkpoints are written under ``args.model_path``.

    Args:
        args: Namespace-like object. Attributes read: ``model_path``,
            ``figure_path``, ``crop_size``, ``vocab_path``, ``image_dir``,
            ``caption_path``, ``batch_size``, ``num_workers``, ``embed_size``,
            ``hidden_size``, ``num_layers``, ``pretraining``,
            ``pretrained_gen_path``, ``pretrained_disc_path``,
            ``gen_pretrain_num_epochs``, ``disc_pretrain_num_epochs``,
            ``log_step`` and ``num_epochs``.

    Returns:
        None.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    if not os.path.exists(args.figure_path):
        os.makedirs(args.figure_path)

    # Image preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper.
    # NOTE(review): pickle.load can execute arbitrary code; args.vocab_path
    # must point at a trusted file.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # Build the models (Gen)
    generator = Generator(args.embed_size, args.hidden_size, len(vocab),
                          args.num_layers)

    # Build the models (Disc)
    discriminator = Discriminator(args.embed_size, args.hidden_size,
                                  len(vocab), args.num_layers)

    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()

    # Loss and Optimizer (Gen)
    mle_criterion = nn.CrossEntropyLoss()
    params_gen = list(generator.parameters())
    optimizer_gen = torch.optim.Adam(params_gen)

    # Loss and Optimizer (Disc)
    params_disc = list(discriminator.parameters())
    optimizer_disc = torch.optim.Adam(params_disc)

    if int(args.pretraining) == 1:
        # Pre-training: train generator with MLE and discriminator with the
        # real + wrong losses (the fake loss is disabled during pre-training).
        total_steps = len(data_loader)
        print(total_steps)
        disc_losses = []
        gen_losses = []
        print('pre-training')
        generator.load_state_dict(torch.load(args.pretrained_gen_path))
        discriminator.load_state_dict(torch.load(args.pretrained_disc_path))

        # Each loss is only assigned while its own network is still inside
        # its epoch budget. Start from NaN so the log line below prints "nan"
        # instead of raising NameError when one of the two phases never ran
        # (e.g. its epoch budget is <= 5 because of the skip below).
        loss_gen = float('nan')
        loss_disc = float('nan')

        for epoch in range(
                max([
                    int(args.gen_pretrain_num_epochs),
                    int(args.disc_pretrain_num_epochs)
                ])):
            # NOTE(review): hard-coded skip of epochs 0-4; together with the
            # checkpoint loading above this looks like a manual resume -
            # confirm before reusing for a fresh run.
            if epoch < 5:
                continue
            for i, (images, captions, lengths, wrong_captions,
                    wrong_lengths) in enumerate(data_loader):
                images = to_var(images, volatile=True)
                captions = to_var(captions)
                wrong_captions = to_var(wrong_captions)
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]

                # Generator step (cross-entropy against the packed targets).
                if epoch < int(args.gen_pretrain_num_epochs):
                    generator.zero_grad()
                    outputs, _ = generator(images, captions, lengths)
                    loss_gen = mle_criterion(outputs, targets)
                    # gen_losses.append(loss_gen.cpu().data.numpy()[0])
                    loss_gen.backward()
                    optimizer_gen.step()

                # Discriminator step (real vs. wrong captions only).
                if epoch < int(args.disc_pretrain_num_epochs):
                    discriminator.zero_grad()
                    rewards_real = discriminator(images, captions, lengths)
                    # rewards_fake = discriminator(images, sampled_captions, sampled_lengths)
                    rewards_wrong = discriminator(images, wrong_captions,
                                                  wrong_lengths)
                    real_loss = -torch.mean(torch.log(rewards_real))
                    # fake_loss = -torch.mean(torch.clamp(torch.log(1 - rewards_fake), min=-1000))
                    # The clamp bounds log(1 - r) from below at -1000.
                    wrong_loss = -torch.mean(
                        torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
                    loss_disc = real_loss + wrong_loss  # + fake_loss, no fake_loss because this is pretraining

                    # disc_losses.append(loss_disc.cpu().data.numpy()[0])
                    loss_disc.backward()
                    optimizer_disc.step()
                if (i + 1) % args.log_step == 0:
                    print(
                        'Epoch [%d], Step [%d], Disc Loss: %.4f, Gen Loss: %.4f'
                        % (epoch + 1, i + 1, loss_disc, loss_gen))
                # Intermediate checkpoints every 500 steps.
                if (i + 1) % 500 == 0:
                    torch.save(
                        discriminator.state_dict(),
                        os.path.join(
                            args.model_path,
                            'pretrained-discriminator-%d-%d.pkl' %
                            (int(epoch) + 1, i + 1)))
                    torch.save(
                        generator.state_dict(),
                        os.path.join(
                            args.model_path, 'pretrained-generator-%d-%d.pkl' %
                            (int(epoch) + 1, i + 1)))

        # Save pretrained models
        torch.save(
            discriminator.state_dict(),
            os.path.join(
                args.model_path, 'pretrained-discriminator-%d.pkl' %
                int(args.disc_pretrain_num_epochs)))
        torch.save(
            generator.state_dict(),
            os.path.join(
                args.model_path, 'pretrained-generator-%d.pkl' %
                int(args.gen_pretrain_num_epochs)))

        # Plot pretraining figures
        # plt.plot(disc_losses, label='pretraining_disc_loss')
        # plt.savefig(args.figure_path + 'pretraining_disc_losses.png')
        # plt.clf()
        #
        # plt.plot(gen_losses, label='pretraining_gen_loss')
        # plt.savefig(args.figure_path + 'pretraining_gen_losses.png')
        # plt.clf()

    else:
        generator.load_state_dict(torch.load(args.pretrained_gen_path))
        discriminator.load_state_dict(torch.load(args.pretrained_disc_path))

    # # Skip the rest for now
    # return

    # Train the Models
    total_step = len(data_loader)
    disc_gan_losses = []
    gen_gan_losses = []
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths, wrong_captions,
                wrong_lengths) in enumerate(data_loader):

            # Set mini-batch dataset
            images = to_var(images, volatile=True)
            captions = to_var(captions)
            wrong_captions = to_var(wrong_captions)

            generator.zero_grad()
            outputs, packed_lengths = generator(images, captions, lengths)
            # Re-wrap the generator output as a PackedSequence and pad it
            # back to a dense batch-first tensor.
            outputs = PackedSequence(outputs, packed_lengths)
            outputs = pad_packed_sequence(outputs,
                                          batch_first=True)  # (b, T, V)

            Tmax = outputs[0].size(1)
            # One reward slot per entry of the padded generator output.
            if torch.cuda.is_available():
                rewards = torch.zeros_like(outputs[0]).type(
                    torch.cuda.FloatTensor)
            else:
                rewards = torch.zeros_like(outputs[0]).type(torch.FloatTensor)

            # getting rewards from disc
            # Only every second time step, starting at t = 2, is scored.
            for t in range(2, Tmax, 2):
                if t >= min(
                        lengths
                ):  # TODO this makes things easier, but could min(lengths) could be too short
                    break

                gen_samples = to_var(torch.zeros(
                    (captions.size(0), Tmax)).type(torch.FloatTensor),
                                     volatile=True)
                # part 1: taken from real caption
                gen_samples[:, :t] = captions[:, :t].data

                predicted_ids, saved_states = generator.pre_compute(
                    gen_samples, t)
                v = predicted_ids
                # part 2: the generator's prediction for position t
                gen_samples[:, t] = v
                # part 3: taken from rollouts
                gen_samples[:, t:] = generator.rollout(gen_samples, t,
                                                       saved_states)

                sampled_lengths = []
                # finding sampled_lengths: position of the first id 2 (+1),
                # or Tmax when the row contains none.
                for batch in range(int(captions.size(0))):
                    for b_t in range(Tmax):
                        if gen_samples[batch,
                                       b_t].cpu().data.numpy() == 2:  # <end>
                            sampled_lengths.append(b_t + 1)
                            break
                        elif b_t == Tmax - 1:
                            sampled_lengths.append(Tmax)

                # sort sampled_lengths in descending order
                # NOTE(review): only the lengths are sorted; the rows of
                # gen_samples keep their original order, so a length may no
                # longer line up with its row - confirm what the
                # discriminator expects.
                sampled_lengths = np.array(sampled_lengths)
                sampled_lengths[::-1].sort()
                sampled_lengths = sampled_lengths.tolist()

                # get rewards from disc
                rewards[:, t, v] = discriminator(images, gen_samples.detach(),
                                                 sampled_lengths)

            # Re-wrap the raw reward data so no gradient flows into the
            # discriminator through the generator loss.
            rewards_detached = rewards.data
            rewards_detached = to_var(rewards_detached)

            # NOTE(review): torch.dot on multi-dimensional inputs depends on
            # the installed PyTorch version - confirm.
            loss_gen = torch.dot(outputs[0], -rewards_detached)
            # gen_gan_losses.append(loss_gen.cpu().data.numpy()[0])

            loss_gen.backward()
            optimizer_gen.step()

            # TODO get sampled_captions
            sampled_ids = generator.sample(images)
            # sampled_captions = torch.zeros_like(sampled_ids).type(torch.LongTensor)
            sampled_lengths = []
            # finding sampled_lengths
            # NOTE(review): 20 is hard-coded; presumably the maximum length
            # produced by generator.sample - confirm.
            for batch in range(int(captions.size(0))):
                for b_t in range(20):
                    #sampled_captions[batch, b_t].data = sampled_ids[batch, b_t].cpu().data.numpy()[0]
                    if sampled_ids[batch,
                                   b_t].cpu().data.numpy() == 2:  # <end>
                        sampled_lengths.append(b_t + 1)
                        break
                    elif b_t == 20 - 1:
                        sampled_lengths.append(20)
            # sort sampled_lengths in descending order
            sampled_lengths = np.array(sampled_lengths)
            sampled_lengths[::-1].sort()
            sampled_lengths = sampled_lengths.tolist()

            # Train discriminator
            discriminator.zero_grad()
            # Clear the volatile flag set when the batch was wrapped above.
            images.volatile = False
            captions.volatile = False
            wrong_captions.volatile = False
            rewards_real = discriminator(images, captions, lengths)
            rewards_fake = discriminator(images, sampled_ids, sampled_lengths)
            rewards_wrong = discriminator(images, wrong_captions,
                                          wrong_lengths)
            real_loss = -torch.mean(torch.log(rewards_real))
            fake_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_fake), min=-1000))
            wrong_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
            loss_disc = real_loss + fake_loss + wrong_loss

            # disc_gan_losses.append(loss_disc.cpu().data.numpy()[0])
            loss_disc.backward()
            optimizer_disc.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [%d/%d], Step [%d/%d], Disc Loss: %.4f, Gen Loss: %.4f'
                    % (epoch, args.num_epochs, i, total_step, loss_disc,
                       loss_gen))

            # Save the models
            # if (i+1) % args.save_step == 0:
            # NOTE(review): the trailing comment says "last iteration", but
            # the condition fires every args.log_step steps.
            if (
                    i + 1
            ) % args.log_step == 0:  # jm: saving at the last iteration instead
                torch.save(
                    generator.state_dict(),
                    os.path.join(
                        args.model_path,
                        'generator-gan-%d-%d.pkl' % (epoch + 1, i + 1)))
                torch.save(
                    discriminator.state_dict(),
                    os.path.join(
                        args.model_path,
                        'discriminator-gan-%d-%d.pkl' % (epoch + 1, i + 1)))
예제 #3
0
def main(args):
    """Jointly train the caption encoder/decoder and the discriminator.

    For every mini-batch the decoder (plus the encoder's ``linear`` and ``bn``
    layers) takes one cross-entropy step against the ground-truth captions,
    then the discriminator takes one step on real, sampled ("fake") and wrong
    captions. At the last step of each epoch the three networks are
    checkpointed under ``args.model_path`` and the discriminator loss history
    is plotted to ``disc_losses.png``.

    Args:
        args: Namespace-like object. Attributes read: ``model_path``,
            ``crop_size``, ``vocab_path``, ``image_dir``, ``caption_path``,
            ``batch_size``, ``num_workers``, ``embed_size``, ``hidden_size``,
            ``num_layers``, ``learning_rate``, ``num_epochs`` and
            ``log_step``.

    Returns:
        None.
    """
    # Make sure the checkpoint directory exists.
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing pipeline.
    # For normalization, see https://github.com/pytorch/vision#models
    preprocessing_steps = [
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    transform = transforms.Compose(preprocessing_steps)

    # Vocabulary wrapper (pickled).
    with open(args.vocab_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    # Shuffled loader yielding
    # (images, captions, lengths, wrong_captions, wrong_lengths).
    data_loader = get_loader(
        args.image_dir,
        args.caption_path,
        vocab,
        transform,
        args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    # Generator side: CNN encoder followed by an RNN decoder.
    # TODO: put these in generator
    encoder = EncoderCNN(args.embed_size)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers)

    # Discriminator side.
    discriminator = Discriminator(args.embed_size, args.hidden_size,
                                  len(vocab), args.num_layers)

    if torch.cuda.is_available():
        for network in (encoder, decoder, discriminator):
            network.cuda()

    # Generator loss/optimizer: the decoder plus only the encoder's linear
    # and batch-norm layers are optimised.
    criterion = nn.CrossEntropyLoss()
    gen_params = list(decoder.parameters())
    gen_params = gen_params + list(encoder.linear.parameters())
    gen_params = gen_params + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(gen_params, lr=args.learning_rate)

    # Discriminator optimizer (Adam defaults).
    optimizer_disc = torch.optim.Adam(list(discriminator.parameters()))

    # File-name templates for the end-of-epoch checkpoints.
    checkpoints = (
        (decoder, 'decoder-%d-%d.pkl'),
        (encoder, 'encoder-%d-%d.pkl'),
        (discriminator, 'discriminator-%d-%d.pkl'),
    )

    total_step = len(data_loader)
    disc_losses = []
    for epoch in range(args.num_epochs):
        for step, batch in enumerate(data_loader):
            images, captions, lengths, wrong_captions, wrong_lengths = batch

            # TODO: train disc before gen

            # Wrap the mini-batch.
            images = to_var(images, volatile=True)
            captions = to_var(captions)
            wrong_captions = to_var(wrong_captions)
            packed_targets = pack_padded_sequence(captions, lengths,
                                                  batch_first=True)
            targets = packed_targets[0]

            # Generator forward pass.
            decoder.zero_grad()
            encoder.zero_grad()
            features = encoder(images)
            outputs = decoder(features, captions, lengths)

            # Sample captions and measure each one up to and including its
            # first '<end>' token (or the full width if there is none).
            sampled_captions = decoder.sample(features)
            sampled_lengths = []
            num_rows = sampled_captions.size(0)
            max_len = sampled_captions.size(1)
            for row in range(num_rows):
                for index, word_id in enumerate(sampled_captions[row, :]):
                    token = vocab.idx2word[word_id.cpu().data.numpy()[0]]
                    if token == '<end>':
                        sampled_lengths.append(index + 1)
                        break
                    if index == max_len - 1:
                        sampled_lengths.append(max_len)
                        break

            # Descending order. Only the lengths are sorted; the rows of
            # sampled_captions keep their original order.
            sampled_lengths = sorted(sampled_lengths, reverse=True)

            # Generator backward pass and update.
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Discriminator update on real / fake / wrong captions.
            discriminator.zero_grad()
            rewards_real = discriminator(images, captions, lengths)
            rewards_fake = discriminator(images, sampled_captions,
                                         sampled_lengths)
            rewards_wrong = discriminator(images, wrong_captions,
                                          wrong_lengths)

            log_real = torch.log(rewards_real)
            real_loss = -torch.mean(log_real)
            # log(1 - r) is bounded from below at -1000 by the clamp.
            log_not_fake = torch.clamp(torch.log(1 - rewards_fake), min=-1000)
            fake_loss = -torch.mean(log_not_fake)
            log_not_wrong = torch.clamp(torch.log(1 - rewards_wrong),
                                        min=-1000)
            wrong_loss = -torch.mean(log_not_wrong)
            loss_disc = real_loss + fake_loss + wrong_loss

            disc_losses.append(loss_disc.cpu().data.numpy()[0])
            loss_disc.backward()
            optimizer_disc.step()

            # Periodic console log of the generator loss and its perplexity.
            if step % args.log_step == 0:
                loss_value = loss.data[0]
                print(
                    'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                    % (epoch, args.num_epochs, step, total_step, loss_value,
                       np.exp(loss_value)))

            # jm: save at the last iteration of the epoch instead of every
            # args.save_step steps.
            if (step + 1) % total_step == 0:
                for network, name_template in checkpoints:
                    checkpoint_name = name_template % (epoch + 1, step + 1)
                    torch.save(
                        network.state_dict(),
                        os.path.join(args.model_path, checkpoint_name))

                # Plot the discriminator loss history at the end of every
                # epoch.
                plt.plot(disc_losses, label='disc loss')
                plt.savefig('disc_losses.png')
                plt.clf()