import tensorflow as tf
from tqdm import tqdm

# Project-local modules (assumed layout; adjust the import paths to match the
# actual package structure of this repo).
from vocab import Vocab
from dataloader import DataLoader
from generator import Generator
from discriminator import Discriminator
from pretrain import GenPretrainer, DiscPretrainer
from rollout import Rollout
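# `get_bleu_score` below is a minimal sketch, assuming `true_seqs` and
# `genned_seqs` are integer-id tensors/arrays of shape (batch, seq_length);
# the project's original helper may differ. Each generated sequence is scored
# against the whole batch of real sequences with NLTK's corpus_bleu.
import numpy as np
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction


def get_bleu_score(true_seqs, genned_seqs):
    refs = np.asarray(true_seqs).tolist()    # reference corpus: real token sequences
    hyps = np.asarray(genned_seqs).tolist()  # hypotheses: generated token sequences
    # Every hypothesis is compared against all references; smoothing avoids
    # zero scores when higher-order n-grams never match.
    smooth = SmoothingFunction().method1
    return corpus_bleu([refs] * len(hyps), hyps, smoothing_function=smooth)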
def main(pretrain_checkpoint_dir,
         train_summary_writer,
         vocab: Vocab,
         dataloader: DataLoader,
         batch_size: int = 64,
         embedding_dim: int = 256,
         seq_length: int = 3000,
         gen_seq_len: int = 3000,
         gen_rnn_units: int = 1024,
         disc_rnn_units: int = 1024,
         epochs: int = 40000,
         pretrain_epochs: int = 4000,
         learning_rate: float = 1e-4,
         rollout_num: int = 2,
         gen_pretrain: bool = False,
         disc_pretrain: bool = False,
         load_gen_weights: bool = False,
         load_disc_weights: bool = False,
         save_gen_weights: bool = True,
         save_disc_weights: bool = True,
         disc_steps: int = 3,
         shuffle: bool = True):
    gen = Generator(dataloader=dataloader, vocab=vocab, batch_size=batch_size,
                    embedding_dim=embedding_dim, seq_length=seq_length,
                    checkpoint_dir=pretrain_checkpoint_dir, rnn_units=gen_rnn_units,
                    start_token=0, learning_rate=learning_rate)
    if load_gen_weights:
        gen.load_weights()
    if gen_pretrain:
        gen_pre_trainer = GenPretrainer(gen, dataloader=dataloader, vocab=vocab,
                                        pretrain_epochs=pretrain_epochs,
                                        tb_writer=train_summary_writer,
                                        learning_rate=learning_rate)
        print('Start pre-training generator...')
        gen_pre_trainer.pretrain(gen_seq_len=gen_seq_len, save_weights=save_gen_weights)

    disc = Discriminator(vocab_size=vocab.vocab_size, embedding_dim=embedding_dim,
                         rnn_units=disc_rnn_units, batch_size=batch_size,
                         checkpoint_dir=pretrain_checkpoint_dir,
                         learning_rate=learning_rate)
    if load_disc_weights:
        disc.load_weights()
    if disc_pretrain:
        disc_pre_trainer = DiscPretrainer(disc, gen, dataloader=dataloader, vocab=vocab,
                                          pretrain_epochs=pretrain_epochs,
                                          tb_writer=train_summary_writer,
                                          learning_rate=learning_rate)
        print('Start pre-training discriminator...')
        disc_pre_trainer.pretrain(save_disc_weights)

    rollout = Rollout(generator=gen, discriminator=disc, vocab=vocab,
                      batch_size=batch_size, seq_length=seq_length,
                      rollout_num=rollout_num)

    with tqdm(desc='Epoch: ', total=epochs, dynamic_ncols=True) as pbar:
        for epoch in range(epochs):
            # Adversarial step: score generated samples with rollout rewards,
            # then update the generator by policy gradient.
            fake_samples = gen.generate()
            rewards = rollout.get_reward(samples=fake_samples)
            gen_loss = gen.train_step(fake_samples, rewards)

            # Train the discriminator several steps per generator step,
            # averaging the loss over the steps.
            real_samples, _ = dataloader.get_batch(shuffle=shuffle, seq_length=seq_length,
                                                   batch_size=batch_size, training=True)
            disc_loss = 0
            for i in range(disc_steps):
                disc_loss += disc.train_step(fake_samples, real_samples) / disc_steps

            # Record train losses.
            with train_summary_writer.as_default():
                tf.summary.scalar('gen_train_loss', gen_loss, step=epoch)
                tf.summary.scalar('disc_train_loss', disc_loss, step=epoch)
                tf.summary.scalar('total_train_loss', disc_loss + gen_loss, step=epoch)
            pbar.set_postfix(gen_train_loss=tf.reduce_mean(gen_loss),
                             disc_train_loss=tf.reduce_mean(disc_loss),
                             total_train_loss=tf.reduce_mean(gen_loss + disc_loss))

            if (epoch + 1) % 5 == 0 or (epoch + 1) == 1:
                print('Saving weights...')
                gen.model.save_weights(gen.checkpoint_prefix)
                disc.model.save_weights(disc.checkpoint_prefix)

                # Evaluate the discriminator on held-out data.
                fake_samples = gen.generate(gen_seq_len)
                real_samples, _ = dataloader.get_batch(shuffle=shuffle, seq_length=gen_seq_len,
                                                       batch_size=batch_size, training=False)
                disc_loss = disc.test_step(fake_samples, real_samples)

                # Evaluate the generator.
                gen_loss = gen.test_step()

                # BLEU score of generated sequences against the held-out batch
                # (see the get_bleu_score sketch above).
                bleu_score = get_bleu_score(true_seqs=real_samples, genned_seqs=fake_samples)
                # genned_sentences = vocab.extract_seqs(fake_samples)  # decode for inspection

                # Record test losses.
                with train_summary_writer.as_default():
                    tf.summary.scalar('disc_test_loss', tf.reduce_mean(disc_loss), step=epoch)
                    tf.summary.scalar('gen_test_loss', tf.reduce_mean(gen_loss), step=epoch)
                    # Offset the step by the pretraining epochs when the generator
                    # was pretrained, so the BLEU curve lines up with pretraining logs.
                    tf.summary.scalar('bleu_score', tf.reduce_mean(bleu_score),
                                      step=epoch + int(gen_pretrain) * pretrain_epochs)
            pbar.update()
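# A hedged usage sketch: the Vocab/DataLoader constructor arguments and file
# paths below are assumptions for illustration; only
# tf.summary.create_file_writer is standard TensorFlow API.
if __name__ == '__main__':
    pretrain_checkpoint_dir = './training_checkpoints'
    train_summary_writer = tf.summary.create_file_writer('./logs/train')
    vocab = Vocab('./data/vocab.txt')                    # hypothetical vocab file
    dataloader = DataLoader(vocab, './data/songs.txt')   # hypothetical corpus path
    main(pretrain_checkpoint_dir, train_summary_writer, vocab, dataloader,
         epochs=40000, pretrain_epochs=4000,
         gen_pretrain=True, disc_pretrain=True)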