Example #1
import tensorflow as tf
from tqdm import tqdm

# Generator, Discriminator, GenPretrainer, DiscPretrainer, Rollout, Vocab and
# DataLoader are this repository's own classes; import them from its modules.

def main(pretrain_checkpoint_dir,
         train_summary_writer,
         vocab: Vocab,
         dataloader: DataLoader,
         batch_size: int = 64,
         embedding_dim: int = 256,
         seq_length: int = 3000,
         gen_seq_len: int = 3000,
         gen_rnn_units: int = 1024,
         disc_rnn_units: int = 1024,
         epochs: int = 40000,
         pretrain_epochs: int = 4000,
         learning_rate: float = 1e-4,
         rollout_num: int = 2,
         gen_pretrain: bool = False,
         disc_pretrain: bool = False,
         load_gen_weights: bool = False,
         load_disc_weights: bool = False,
         save_gen_weights: bool = True,
         save_disc_weights: bool = True,
         disc_steps: int = 3,
         shuffle: bool = True):
    gen = Generator(dataloader=dataloader,
                    vocab=vocab,
                    batch_size=batch_size,
                    embedding_dim=embedding_dim,
                    seq_length=seq_length,
                    checkpoint_dir=pretrain_checkpoint_dir,
                    rnn_units=gen_rnn_units,
                    start_token=0,
                    learning_rate=learning_rate)
    if load_gen_weights:
        gen.load_weights()
    if gen_pretrain:
        gen_pre_trainer = GenPretrainer(gen,
                                        dataloader=dataloader,
                                        vocab=vocab,
                                        pretrain_epochs=pretrain_epochs,
                                        tb_writer=train_summary_writer,
                                        learning_rate=learning_rate)
        print('Start pre-training generator...')
        gen_pre_trainer.pretrain(gen_seq_len=gen_seq_len,
                                 save_weights=save_gen_weights)

    disc = Discriminator(vocab_size=vocab.vocab_size,
                         embedding_dim=embedding_dim,
                         rnn_units=disc_rnn_units,
                         batch_size=batch_size,
                         checkpoint_dir=pretrain_checkpoint_dir,
                         learning_rate=learning_rate)
    if load_disc_weights:
        disc.load_weights()
    if disc_pretrain:
        disc_pre_trainer = DiscPretrainer(disc,
                                          gen,
                                          dataloader=dataloader,
                                          vocab=vocab,
                                          pretrain_epochs=pretrain_epochs,
                                          tb_writer=train_summary_writer,
                                          learning_rate=learning_rate)
        print('Start pre-training discriminator...')
        disc_pre_trainer.pretrain(save_disc_weights)
    rollout = Rollout(generator=gen,
                      discriminator=disc,
                      vocab=vocab,
                      batch_size=batch_size,
                      seq_length=seq_length,
                      rollout_num=rollout_num)

    with tqdm(desc='Epoch: ', total=epochs, dynamic_ncols=True) as pbar:
        for epoch in range(epochs):
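            # Adversarial step: sample sequences from the generator, score
            # them with discriminator-based rollout rewards, and apply a
            # policy-gradient update to the generator before training the
            # discriminator on real vs. generated batches.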
            fake_samples = gen.generate()
            rewards = rollout.get_reward(samples=fake_samples)
            gen_loss = gen.train_step(fake_samples, rewards)
            real_samples, _ = dataloader.get_batch(shuffle=shuffle,
                                                   seq_length=seq_length,
                                                   batch_size=batch_size,
                                                   training=True)
            disc_loss = 0
            for i in range(disc_steps):
                disc_loss += disc.train_step(fake_samples,
                                             real_samples) / disc_steps

            with train_summary_writer.as_default():
                tf.summary.scalar('gen_train_loss', gen_loss, step=epoch)
                tf.summary.scalar('disc_train_loss', disc_loss, step=epoch)
                tf.summary.scalar('total_train_loss',
                                  disc_loss + gen_loss,
                                  step=epoch)

            pbar.set_postfix(gen_train_loss=tf.reduce_mean(gen_loss),
                             disc_train_loss=tf.reduce_mean(disc_loss),
                             total_train_loss=tf.reduce_mean(gen_loss +
                                                             disc_loss))

            if (epoch + 1) % 5 == 0 or (epoch + 1) == 1:
                print('Saving weights...')
                # Save weights
                gen.model.save_weights(gen.checkpoint_prefix)
                disc.model.save_weights(disc.checkpoint_prefix)
                # gen.model.save('gen.h5')
                # disc.model.save('disc.h5')

                # Test the discriminator
                fake_samples = gen.generate(gen_seq_len)
                real_samples, _ = dataloader.get_batch(shuffle=shuffle,
                                                       seq_length=gen_seq_len,
                                                       batch_size=batch_size,
                                                       training=False)
                disc_loss = disc.test_step(fake_samples, real_samples)

                # Test the generator
                gen_loss = gen.test_step()

                # Compute the BLEU score
                # bleu_score = get_bleu_score(true_seqs=real_samples, genned_seqs=fake_samples)
                genned_sentences = vocab.extract_seqs(fake_samples)
                # print(genned_sentences)
                # print(vocab.idx2char[fake_samples[0]])

                # Log test losses
                with train_summary_writer.as_default():
                    tf.summary.scalar('disc_test_loss',
                                      tf.reduce_mean(disc_loss),
                                      step=epoch)
                    tf.summary.scalar('gen_test_loss',
                                      tf.reduce_mean(gen_loss),
                                      step=epoch)
                    # tf.summary.scalar('bleu_score', tf.reduce_mean(bleu_score), step=epoch + gen_pretrain * pretrain_epochs)

            pbar.update()
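
For orientation, here is a minimal, hypothetical invocation of the main()
entry point above. The Vocab and DataLoader constructors and all paths are
assumptions; only tf.summary.create_file_writer is standard TensorFlow API.

import tensorflow as tf

# Hypothetical setup; the real constructor signatures live in this repository.
vocab = Vocab(corpus_path='data/corpus.txt')      # assumed signature
dataloader = DataLoader(vocab=vocab)              # assumed signature
train_summary_writer = tf.summary.create_file_writer('logs/train')

main(pretrain_checkpoint_dir='checkpoints/pretrain',
     train_summary_writer=train_summary_writer,
     vocab=vocab,
     dataloader=dataloader,
     gen_pretrain=True,
     disc_pretrain=True)
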
Example #2
if gen_pretrain:
    gen_pre_trainer = GenPretrainer(generator=gen,
                                    pretrain_epochs=PRETRAIN_EPOCHS,
                                    songs=songs,
                                    char2idx=char2idx,
                                    idx2char=idx2char,
                                    tb_writer=train_summary_writer,
                                    learning_rate=1e-4)
    print('Start pre-training generator...')
    gen_pre_trainer.pretrain(gen_seq_len, save_gen_weights)

disc = Discriminator(vocab_size=vocab_size,
                     embedding_dim=embedding_dim,
                     rnn_units=gen_rnn_units,
                     batch_size=batch_size,
                     checkpoint_dir=pretrain_checkpoint_dir)
if load_disc_weights:
    disc.load_weights()

if disc_pretrain:
    disc_pre_trainer = DiscPretrainer(discriminator=disc,
                                      generator=gen,
                                      pretrain_epochs=PRETRAIN_EPOCHS,
                                      songs=songs,
                                      char2idx=char2idx,
                                      idx2char=idx2char,
                                      tb_writer=train_summary_writer,
                                      learning_rate=1e-4)
    print('Start pre-training discriminator...')
    disc_pre_trainer.pretrain(save_disc_weights)


Example #3
import math

import cv2
import numpy as np
from keras.models import Sequential
from keras.utils import Progbar

# Discriminator and Generator are this repository's own model builders.

class DCGAN(object):
    def __init__(self, input_dim, image_shape):
        self.input_dim = input_dim
        self.d = Discriminator(image_shape).get_model()
        self.g = Generator(input_dim, image_shape).get_model()

    def compile(self, g_optim, d_optim):
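        # Freeze the discriminator inside the stacked generator+discriminator
        # model so generator updates leave its weights untouched; Keras
        # snapshots `trainable` at compile time, so re-enabling it afterwards
        # keeps the standalone discriminator trainable.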
        self.d.trainable = False
        self.dcgan = Sequential([self.g, self.d])
        self.dcgan.compile(loss='binary_crossentropy', optimizer=g_optim)
        self.d.trainable = True
        self.d.compile(loss='binary_crossentropy', optimizer=d_optim)

    def train(self, epochs, batch_size, X_train):
        g_losses = []
        d_losses = []
        for epoch in range(epochs):
            np.random.shuffle(X_train)
            n_iter = X_train.shape[0] // batch_size
            progress_bar = Progbar(target=n_iter)
            for index in range(n_iter):
                # create random noise -> N latent vectors
                noise = np.random.uniform(-1,
                                          1,
                                          size=(batch_size, self.input_dim))

                # load real data & generate fake data
                image_batch = X_train[index * batch_size:(index + 1) *
                                      batch_size]
                for i in range(batch_size):
                    if np.random.random() > 0.5:
                        image_batch[i] = np.fliplr(image_batch[i])
                    if np.random.random() > 0.5:
                        image_batch[i] = np.flipud(image_batch[i])
                generated_images = self.g.predict(noise, verbose=0)

                # attach label for training discriminator
                X = np.concatenate((image_batch, generated_images))
                y = np.array([1] * batch_size + [0] * batch_size)

                # training discriminator
                d_loss = self.d.train_on_batch(X, y)

                # training generator
                g_loss = self.dcgan.train_on_batch(noise,
                                                   np.array([1] * batch_size))

                progress_bar.update(index,
                                    values=[('g', g_loss), ('d', d_loss)])
            g_losses.append(g_loss)
            d_losses.append(d_loss)
            if (epoch + 1) % 10 == 0:
                image = self.combine_images(generated_images)
                image = (image + 1) / 2.0 * 255.0
                cv2.imwrite('./result/' + str(epoch) + ".png", image)
            print('\nEpoch ' + str(epoch) + ' end')

            # save weights every 50 epochs
            if (epoch + 1) % 50 == 0:
                self.g.save_weights('weights/generator_' + str(epoch) + '.h5',
                                    True)
                self.d.save_weights(
                    'weights/discriminator_' + str(epoch) + '.h5', True)
        return g_losses, d_losses

    def load_weights(self, g_weight, d_weight):
        self.g.load_weights(g_weight)
        self.d.load_weights(d_weight)

    def combine_images(self, generated_images):
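        # Tile a batch of generated images into one roughly square grid.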
        num = generated_images.shape[0]
        width = int(math.sqrt(num))
        height = int(math.ceil(float(num) / width))
        shape = generated_images.shape[1:4]
        image = np.zeros((height * shape[0], width * shape[1], shape[2]),
                         dtype=generated_images.dtype)
        for index, img in enumerate(generated_images):
            i = int(index / width)
            j = index % width
            image[i * shape[0]:(i + 1) * shape[0],
                  j * shape[1]:(j + 1) * shape[1], :] = img[:, :, :]
        return image
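
A usage sketch under assumed data: MNIST digits rescaled to [-1, 1], matching
the (image + 1) / 2 rescaling that train() applies before writing samples.
The latent size of 100 and the Adam settings are conventional DCGAN choices,
not values taken from this code; the old-style Keras `lr` argument matches
the Keras 2 idioms used above.

from keras.datasets import mnist
from keras.optimizers import Adam

(X_train, _), _ = mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32')
X_train = X_train / 127.5 - 1.0  # scale pixels to [-1, 1]

# train() writes sample grids to ./result/ and checkpoints to ./weights/,
# so both directories must exist beforehand.
dcgan = DCGAN(input_dim=100, image_shape=(28, 28, 1))
dcgan.compile(g_optim=Adam(lr=2e-4, beta_1=0.5),
              d_optim=Adam(lr=2e-4, beta_1=0.5))
g_losses, d_losses = dcgan.train(epochs=100, batch_size=64, X_train=X_train)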