def main():
    """Train a REINFORCE-style text GAN and dump samples plus training curves.

    Relies on module-level configuration defined elsewhere in the script
    (``data_path``, ``batch_size``, ``vocab_size``, ``g_emb_dim``,
    ``g_hidden_dim``, ``d_hidden_dim``, ``total_epochs``, ``g_seq_length``,
    ``opt``).

    Side effects:
        * every third epoch, writes a batch of generated strings to
          ``./data/reinforce_gan_data_epoch<i>.txt`` and prints their
          goodness score;
        * writes a final batch of generated strings to
          ``./data/reinforce_gan_final_data.txt``;
        * saves the generator-reward and discriminator-loss curves to
          ``reward_and_D_loss.png``.
    """
    data_loader = DataLoader(data_path, batch_size)
    generator = Generator(vocab_size, g_emb_dim, g_hidden_dim)
    discriminator = Discriminator(vocab_size, d_hidden_dim)

    gen_optimizer = optim.Adam(generator.parameters(), lr=0.0001)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

    bce_criterion = nn.BCELoss()

    if opt.cuda:
        generator.cuda()
        discriminator.cuda()

    # Accumulate the histories over all epochs. Previously each epoch
    # overwrote them, so the plot showed only the last epoch, and it raised
    # NameError when total_epochs == 0.
    all_G_rewards = []
    all_D_losses = []
    for i in tqdm(range(total_epochs)):
        print("EPOCH:", i)
        g_rewards, d_losses = train_gan_epoch(discriminator, generator,
                                              data_loader, gen_optimizer,
                                              disc_optimizer, bce_criterion)
        all_G_rewards.extend(g_rewards)
        all_D_losses.extend(d_losses)

        if i % 3 == 0:
            # Periodic snapshot of the generator's output.
            sample = generator.sample(batch_size, g_seq_length)
            all_strings = list(data_loader.convert_to_char(sample))
            with open('./data/reinforce_gan_data_epoch' + str(i) + '.txt',
                      'w') as f:
                f.writelines(each_str + '\n' for each_str in all_strings)
            print("Goodness string:",
                  Utils.get_data_goodness_score(all_strings))

    sample = generator.sample(batch_size, g_seq_length)

    with open('./data/reinforce_gan_final_data.txt', 'w') as f:
        for each_str in data_loader.convert_to_char(sample):
            f.write(each_str + '\n')

    plt.plot(list(range(len(all_G_rewards))), all_G_rewards, label='G reward')
    plt.plot(list(range(len(all_D_losses))), all_D_losses, label='D loss')
    plt.legend()
    plt.savefig('reward_and_D_loss.png')
# Example #2
# 0
def main():
    """Pre-train the generator with MLE, then train the GAN adversarially.

    Output file names suggest this is a Gumbel-softmax GAN. Relies on
    module-level configuration defined elsewhere (``data_path``,
    ``batch_size``, ``vocab_size``, ``g_emb_dim``, ``g_hidden_dim``,
    ``d_hidden_dim``, ``total_epochs``, ``g_seq_length``, ``checkpoint_dir``,
    ``opt``).

    Side effects:
        * writes generated strings to ``./data/gumbel_softmax_gan_gen.txt``;
        * saves generator and discriminator checkpoints in
          ``checkpoint_dir``;
        * shows the generator/discriminator loss curves.
    """
    data_loader = DataLoader(data_path, batch_size)
    generator = Generator(vocab_size, g_emb_dim, g_hidden_dim)
    discriminator = Discriminator(vocab_size, d_hidden_dim)

    gen_optimizer = optim.Adam(generator.parameters(), lr=0.0001)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

    bce_criterion = nn.BCELoss()
    # Summed (not averaged) NLL for MLE pre-training.
    gen_criterion = nn.NLLLoss(size_average=False)

    if opt.cuda:
        generator.cuda()
        # train_gan_epoch trains the discriminator on generator output, so
        # both networks must be on the same device. Previously only the
        # generator was moved.
        discriminator.cuda()

    # NOTE(review): the last argument (10) is presumably the number of
    # pre-training epochs; confirm against pretrain_lstm.
    pretrain_lstm(generator, data_loader, gen_optimizer, gen_criterion, 10)

    all_G_losses = []
    all_D_losses = []
    for _ in tqdm(range(total_epochs)):
        g_losses, d_losses = train_gan_epoch(discriminator, generator,
                                             data_loader, gen_optimizer,
                                             disc_optimizer, bce_criterion)
        all_G_losses += g_losses
        all_D_losses += d_losses

    sample = generator.sample(batch_size, g_seq_length)

    print(generator)

    with open('./data/gumbel_softmax_gan_gen.txt', 'w') as f:
        for each_str in data_loader.convert_to_char(sample):
            f.write(each_str + '\n')

    gen_file_name = 'gen_gumbel_softmax_' + str(total_epochs) + '.pth'
    disc_file_name = 'disc_gumbel_softmax_' + str(total_epochs) + '.pth'

    Utils.save_checkpoints(checkpoint_dir, gen_file_name, generator)
    Utils.save_checkpoints(checkpoint_dir, disc_file_name, discriminator)

    plt.plot(list(range(len(all_G_losses))),
             all_G_losses,
             'g-',
             label='gen loss')
    plt.plot(list(range(len(all_D_losses))),
             all_D_losses,
             'b-',
             label='disc loss')
    plt.legend()
    plt.show()
# Example #3
# 0
def main():
    """Train a PlainLSTM language model by maximum likelihood.

    The loss is the NLL summed over all tokens in a batch
    (``size_average=False``). Relies on module-level configuration defined
    elsewhere (``data_path``, ``batch_size``, ``vocab_size``, ``g_emb_dim``,
    ``g_hidden_dim``, ``total_epochs``, ``g_seq_length``, ``opt``).

    Side effects:
        * writes generated strings to ``./data/lstm_mle_gen_data.txt``;
        * shows the per-batch loss curve.
    """
    data_loader = DataLoader(data_path, batch_size)
    generator = PlainLSTM(vocab_size, g_emb_dim, g_hidden_dim)
    optimizer = optim.Adam(generator.parameters())
    if opt.cuda:
        generator.cuda()

    # The criterion is stateless, so build it once instead of once per batch.
    gen_criterion = nn.NLLLoss(size_average=False)

    losses_array = []  # one summed-NLL value per batch
    for _ in tqdm(range(total_epochs)):
        for data, target in data_loader:
            # Inline comments in the original give both as
            # batch_size x sequence_length (e.g. 16x15).
            data = Variable(data)
            target = Variable(target)
            if opt.cuda:
                data, target = data.cuda(), target.cuda()
            pred = generator(data)

            # Flatten so every token position is one NLL sample.
            target = target.view(-1)
            pred = pred.view(-1, vocab_size)

            loss = gen_criterion(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses_array.append(loss.data[0])

        # The loader must be rewound before the next epoch.
        data_loader.reset()

    sample = generator.sample(batch_size, g_seq_length)

    with open('./data/lstm_mle_gen_data.txt', 'w') as f:
        for each_str in data_loader.convert_to_char(sample):
            f.write(each_str + '\n')
    plt.plot(losses_array)
    plt.show()
# Example #4
# 0
def _evaluate_generator(generator, step, label):
    """Sample from ``generator``, print quality metrics and return them.

    Writes ``GENERATED_NUM`` samples to ``EVAL_FILE`` via
    ``generate_samples``, reloads them as characters, then prints the strings,
    the goodness score, the KL score (computed only when ``SPACES`` is False,
    otherwise reported as -1) and the character distribution.

    Args:
        generator: network to sample from.
        step: epoch/batch index used in the printed messages.
        label: prefix of the score lines (``'Epoch'`` or ``'Batch'``).

    Returns:
        ``(eval_score, generated_string)``, where ``generated_string`` is the
        value returned by ``DataLoader.convert_to_char``.
    """
    samples = generate_samples(generator, BATCH_SIZE, GENERATED_NUM,
                               EVAL_FILE)
    eval_iter = DataLoader(EVAL_FILE, BATCH_SIZE)
    generated_string = eval_iter.convert_to_char(samples)
    print(generated_string)
    eval_score = get_data_goodness_score(generated_string, SPACES)
    if SPACES == False:  # noqa: E712 (kept: same comparison as before)
        kl_score = get_data_freq(generated_string)
    else:
        kl_score = -1
    freq_score = get_char_freq(generated_string, SPACES)
    print('%s [%d] Generation Score: %f' % (label, step, eval_score))
    print('%s [%d] KL Score: %f' % (label, step, kl_score))
    # This line is always prefixed 'Epoch', as in the original messages.
    print('Epoch [{}] Character distribution: {}'.format(
        step, list(freq_score)))
    return eval_score, generated_string


def _report_mle_epoch(generator, epoch, loss, scores, score_logger):
    """Report one MLE epoch: loss, sample quality, checkpoint, Visdom plot.

    Appends the goodness score to ``scores`` in place and saves the generator
    weights under ``checkpoints/``. ``score_logger`` is None when
    visualisation is disabled.
    """
    print('Epoch [%d] Model Loss: %f' % (epoch, loss))
    eval_score, _ = _evaluate_generator(generator, epoch, 'Epoch')
    scores.append(eval_score)
    torch.save(
        generator.state_dict(),
        f"checkpoints/MLE_space_{SPACES}_length_{SEQ_LEN}_preTrainG_epoch_{epoch}.pth"
    )
    if score_logger is not None:
        score_logger.log(epoch, eval_score)


def main(opt):
    """Train a text GAN whose generator gradient estimator is chosen by ``GD``.

    Steps:
      1. optionally create Visdom loggers;
      2. build the generator, the LSTM discriminator and the ``c_phi_hat``
         annex network used in the RELAX terms;
      3. pre-train the generator with MLE (or load ``weights_path``), plus
         extra MLE epochs when ``GD == "MLE"``;
      4. pre-train the discriminator on real vs. generated samples;
      5. run ``TOTAL_BATCH`` adversarial rounds. Each round updates the
         generator, evaluates and checkpoints it, then updates the
         discriminator.

    Hyper-parameters are module-level globals defined elsewhere.

    Args:
        opt: parsed options. Only ``opt.cuda`` and ``opt.visualize`` are read.
    """
    cuda = opt.cuda
    visualize = opt.visualize
    print(f"cuda = {cuda}, visualize = {opt.visualize}")

    # Loggers stay None when visualisation (or their phase) is disabled.
    pretrain_G_score_logger = None
    pretrain_D_loss_logger = None
    G_variance_logger = None
    if visualize:
        # The GD == "MLE" phase also logs here, so create the window for it
        # even when PRE_EPOCH_GEN is 0. Previously this was a NameError.
        if PRE_EPOCH_GEN > 0 or GD == "MLE":
            pretrain_G_score_logger = VisdomPlotLogger(
                'line', opts={'title': 'Pre-train G Goodness Score'})
        if PRE_EPOCH_DIS > 0:
            pretrain_D_loss_logger = VisdomPlotLogger(
                'line', opts={'title': 'Pre-train D Loss'})
        adversarial_G_score_logger = VisdomPlotLogger(
            'line',
            opts={
                'title': f'Adversarial G {GD} Goodness Score',
                'Y': '{0, 13}',
                # NOTE(review): plain string; TOTAL_BATCH is not interpolated.
                'X': '{0, TOTAL_BATCH}'
            })
        if CHECK_VARIANCE:
            G_variance_logger = VisdomPlotLogger(
                'line', opts={'title': f'Adversarial G {GD} Variance'})
        G_text_logger = VisdomTextLogger(update_type='APPEND')
        adversarial_D_loss_logger = VisdomPlotLogger(
            'line', opts={'title': 'Adversarial Batch D Loss'})

    # Define networks
    generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, cuda)
    n_gen = Variable(torch.Tensor([get_n_params(generator)]))
    use_cuda = False
    if cuda:
        n_gen = n_gen.cuda()
        use_cuda = True
    print('Number of parameters in the generator: {}'.format(n_gen))
    discriminator = LSTMDiscriminator(d_num_class, VOCAB_SIZE,
                                      d_lstm_hidden_dim, use_cuda)
    c_phi_hat = AnnexNetwork(d_num_class, VOCAB_SIZE, d_emb_dim,
                             c_filter_sizes, c_num_filters, d_dropout,
                             BATCH_SIZE, g_sequence_len)
    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        c_phi_hat = c_phi_hat.cuda()

    # Load the real (positive) data from file.
    print('Generating data ...')
    gen_data_iter = DataLoader(POSITIVE_FILE, BATCH_SIZE)

    gen_criterion = nn.NLLLoss(size_average=False)
    gen_optimizer = optim.Adam(generator.parameters())
    if cuda:
        gen_criterion = gen_criterion.cuda()

    # Pretrain the generator using MLE, or restore pre-trained weights.
    pre_train_scores = []
    if MLE:
        print('Pretrain with MLE ...')
        for epoch in range(int(np.ceil(PRE_EPOCH_GEN))):
            loss = train_epoch(generator, gen_data_iter, gen_criterion,
                               gen_optimizer, PRE_EPOCH_GEN, epoch, cuda)
            _report_mle_epoch(generator, epoch, loss, pre_train_scores,
                              pretrain_G_score_logger)
    else:
        generator.load_state_dict(torch.load(weights_path))

    # Finish training with MLE only.
    if GD == "MLE":
        mle_batches = int(GENERATED_NUM / BATCH_SIZE)
        for epoch in range(3 * mle_batches):
            loss = train_epoch_batch(generator, gen_data_iter, gen_criterion,
                                     gen_optimizer, PRE_EPOCH_GEN, epoch,
                                     mle_batches, cuda)
            _report_mle_epoch(generator, epoch, loss, pre_train_scores,
                              pretrain_G_score_logger)

    # Pretrain the discriminator.
    dis_criterion = nn.NLLLoss(size_average=False)
    dis_optimizer = optim.Adam(discriminator.parameters())
    if cuda:
        dis_criterion = dis_criterion.cuda()
    print('Pretrain Discriminator ...')
    for epoch in range(PRE_EPOCH_DIS):
        # Called for its side effect of writing fresh negatives to
        # NEGATIVE_FILE; the return value is not needed here.
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE)
        dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE,
                                    SEQ_LEN)
        for _ in range(PRE_ITER_DIS):
            loss = train_epoch(discriminator, dis_data_iter, dis_criterion,
                               dis_optimizer, 1, 1, cuda)
            print('Epoch [%d], loss: %f' % (epoch, loss))
            if pretrain_D_loss_logger is not None:
                pretrain_D_loss_logger.log(epoch, loss)

    # Adversarial training
    rollout = Rollout(generator, UPDATE_RATE)
    print('#####################################################')
    print('Start Adversarial Training...\n')

    gen_gan_loss = GANLoss()
    gen_gan_optm = optim.Adam(generator.parameters())
    if cuda:
        gen_gan_loss = gen_gan_loss.cuda()

    dis_criterion_bce = nn.BCELoss()
    # A new optimiser for the adversarial phase, so the Adam state starts fresh.
    dis_optimizer = optim.Adam(discriminator.parameters())

    c_phi_hat_loss = VarianceLoss()
    if cuda:
        c_phi_hat_loss = c_phi_hat_loss.cuda()
    c_phi_hat_optm = optim.Adam(c_phi_hat.parameters())

    # Same list object as pre_train_scores, so the score curve continues
    # from pre-training.
    gen_scores = pre_train_scores

    for total_batch in range(TOTAL_BATCH):
        # Train the generator for G_STEPS steps.
        for it in range(G_STEPS):
            samples = generator.sample(BATCH_SIZE, g_sequence_len)
            # samples has size (BS, sequence_len).
            # Generator input: prepend a zero column, drop the last column.
            zeros = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor)
            if samples.is_cuda:
                zeros = zeros.cuda()
            inputs = Variable(
                torch.cat([zeros, samples.data], dim=1)[:, :-1].contiguous())
            if cuda:
                inputs = inputs.cuda()
            # Calculate the reward. exp is applied exactly once, and the
            # result is flattened on both devices. Previously the CUDA path
            # applied torch.exp twice and the CPU path was not flattened.
            rewards = rollout.get_reward(samples, discriminator, VOCAB_SIZE,
                                         cuda)
            rewards = Variable(torch.Tensor(rewards))
            if cuda:
                rewards = rewards.cuda()
            rewards = torch.exp(rewards).contiguous().view((-1, ))
            # rewards is now 1-D. NOTE(review): the original comment says
            # size (BS); confirm the shape returned by Rollout.get_reward.
            prob = generator.forward(inputs)
            # prob has size (BS*sequence_len, VOCAB_SIZE)
            # 3.a
            theta_prime = g_output_prob(prob)
            # theta_prime has size (BS*sequence_len, VOCAB_SIZE)
            # 3.e and f
            c_phi_z_ori, c_phi_z_tilde_ori = c_phi_out(
                GD,
                c_phi_hat,
                theta_prime,
                discriminator,
                temperature=DEFAULT_TEMPERATURE,
                eta=DEFAULT_ETA,
                cuda=cuda)
            c_phi_z_ori = torch.exp(c_phi_z_ori)
            c_phi_z_tilde_ori = torch.exp(c_phi_z_tilde_ori)
            c_phi_z = torch.sum(c_phi_z_ori[:, 1]) / BATCH_SIZE
            c_phi_z_tilde = -torch.sum(c_phi_z_tilde_ori[:, 1]) / BATCH_SIZE
            if cuda:
                c_phi_z = c_phi_z.cuda()
                c_phi_z_tilde = c_phi_z_tilde.cuda()
            # 3.i: collects per-sample gradient contributions.
            grads = []
            # 3.h optimization step: first, empty the gradient buffers.
            gen_gan_optm.zero_grad()
            # Re-arrange prob to (BS, sequence_len, VOCAB_SIZE).
            new_prob = prob.view((BATCH_SIZE, g_sequence_len, VOCAB_SIZE))
            # 3.g new gradient loss for RELAX
            batch_i_grads_1 = gen_gan_loss.forward_reward_grads(
                samples, new_prob, rewards, generator, BATCH_SIZE,
                g_sequence_len, VOCAB_SIZE, cuda)
            batch_i_grads_2 = gen_gan_loss.forward_reward_grads(
                samples, new_prob, c_phi_z_tilde_ori[:, 1], generator,
                BATCH_SIZE, g_sequence_len, VOCAB_SIZE, cuda)
            # Both are expected to be BATCH_SIZE lists of per-parameter
            # gradients (per the original comment).
            # 3.i: batch_grads aliases batch_i_grads_1 and is updated in place.
            batch_grads = batch_i_grads_1
            if GD != "REINFORCE":
                for i in range(len(batch_i_grads_1)):
                    for j in range(len(batch_i_grads_1[i])):
                        batch_grads[i][j] = torch.add(batch_grads[i][j], (-1) *
                                                      batch_i_grads_2[i][j])
            grads.append(batch_grads)
            # Now train the generator.
            generator.zero_grad()
            for i in range(g_sequence_len):
                # 3.g new gradient loss for RELAX
                cond_prob = gen_gan_loss.forward_reward(
                    i, samples, new_prob, rewards, BATCH_SIZE, g_sequence_len,
                    VOCAB_SIZE, cuda)
                c_term = gen_gan_loss.forward_reward(i, samples, new_prob,
                                                     c_phi_z_tilde_ori[:, 1],
                                                     BATCH_SIZE,
                                                     g_sequence_len,
                                                     VOCAB_SIZE, cuda)
                if GD != "REINFORCE":
                    cond_prob = torch.add(cond_prob, (-1) * c_term)
                new_prob[:, i, :].backward(cond_prob, retain_graph=True)
            # 3.h: the last two terms of the RELAX equation.
            if GD != "REINFORCE":
                c_phi_z.backward(retain_graph=True)
                c_phi_z_tilde.backward(retain_graph=True)
            gen_gan_optm.step()
            # 3.i
            if CHECK_VARIANCE:
                # c_phi_z term: per-sample gradients w.r.t. generator params.
                partial_grads = []
                for j in range(BATCH_SIZE):
                    generator.zero_grad()
                    c_phi_z_ori[j, 1].backward(retain_graph=True)
                    partial_grads.append(
                        [p.grad.clone() for p in generator.parameters()])
                grads.append(partial_grads)
                # c_phi_z_tilde term (negated).
                partial_grads = []
                for j in range(BATCH_SIZE):
                    generator.zero_grad()
                    c_phi_z_tilde_ori[j, 1].backward(retain_graph=True)
                    partial_grads.append(
                        [-1 * p.grad.clone() for p in generator.parameters()])
                grads.append(partial_grads)
                # To inspect the contributions, print grads[k][0][6] for
                # k in 0..2.
                # grads has length 3; each entry has length BATCH_SIZE.
                # 3.j: sum the three contributions in place into grads[0].
                all_grads = grads[0]
                if GD != "REINFORCE":
                    for i in range(len(grads[0])):
                        for j in range(len(grads[0][i])):
                            all_grads[i][j] = torch.add(
                                torch.add(all_grads[i][j], grads[1][i][j]),
                                grads[2][i][j])
                c_phi_hat_optm.zero_grad()
                var_loss = c_phi_hat_loss.forward(all_grads, cuda)  #/n_gen
                true_variance = c_phi_hat_loss.forward_variance(
                    all_grads, cuda)
                var_loss.backward()
                c_phi_hat_optm.step()
                print(
                    'Batch [{}] Estimate of the variance of the gradient at step {}: {}'
                    .format(total_batch, it, true_variance[0]))
                if G_variance_logger is not None:
                    G_variance_logger.log((total_batch + it), true_variance[0])

        # Evaluate the generator outputs after every batch. The original
        # condition `total_batch % 1 == 0 or ...` was always true.
        eval_score, generated_string = _evaluate_generator(
            generator, total_batch, 'Batch')
        gen_scores.append(eval_score)

        # Checkpoint & visualize
        if total_batch % 10 == 0 or total_batch == TOTAL_BATCH - 1:
            torch.save(
                generator.state_dict(),
                f'checkpoints/{GD}_G_space_{SPACES}_pretrain_{PRE_EPOCH_GEN}_batch_{total_batch}.pth'
            )
        if visualize:
            for line in generated_string:
                G_text_logger.log(line)
            adversarial_G_score_logger.log(total_batch, eval_score)

        # Train the discriminator: real data vs. fresh generator samples.
        for b in range(D_EPOCHS):

            for data, _ in gen_data_iter:

                data = Variable(data)
                real_data = convert_to_one_hot(data, VOCAB_SIZE, cuda)
                real_target = Variable(torch.ones((data.size(0), 1)))
                samples = generator.sample(data.size(0),
                                           g_sequence_len)  # bs x seq_len
                fake_data = convert_to_one_hot(
                    samples, VOCAB_SIZE, cuda)  # bs x seq_len x vocab_size
                fake_target = Variable(torch.zeros((data.size(0), 1)))

                if cuda:
                    real_target = real_target.cuda()
                    fake_target = fake_target.cuda()
                    real_data = real_data.cuda()
                    fake_data = fake_data.cuda()

                # exp turns column 1 of the discriminator output into the
                # input for BCELoss. NOTE(review): assumes the discriminator
                # returns log-probabilities; confirm.
                real_pred = torch.exp(discriminator(real_data)[:, 1])
                fake_pred = torch.exp(discriminator(fake_data)[:, 1])

                D_real_loss = dis_criterion_bce(real_pred, real_target)
                D_fake_loss = dis_criterion_bce(fake_pred, fake_target)
                D_loss = D_real_loss + D_fake_loss
                dis_optimizer.zero_grad()
                D_loss.backward()
                dis_optimizer.step()

            gen_data_iter.reset()

            print('Batch [{}] Discriminator Loss at step and epoch {}: {}'.
                  format(total_batch, b, D_loss.data[0]))

        if visualize:
            adversarial_D_loss_logger.log(total_batch, D_loss.data[0])

    if not visualize:
        plt.plot(gen_scores)
        plt.ylim((0, 13))
        plt.title('{}_after_{}_epochs_of_pretraining'.format(
            GD, PRE_EPOCH_GEN))
        plt.show()