Example 1
def train(dataset,
          epochs=30,
          num_images=1,
          latent_dim=256,
          learning_rate_g=0.00005,
          learning_rate_d=0.00005):
    discriminator = Discriminator().discriminator
    generator = Generator(latent_dim).generator

    gan = GAN(discriminator, generator, latent_dim)
    gan.compile(
        d_optimizer=keras.optimizers.Adam(learning_rate_d),
        g_optimizer=keras.optimizers.Adam(learning_rate_g),
        loss_function=keras.losses.BinaryCrossentropy(from_logits=True),
    )

    gan.fit(dataset,
            epochs=epochs,
            callbacks=[GANMonitor(num_images, latent_dim)])

    # gan.save("gan_model.h5")
    generator.save("generator_model.h5")
    discriminator.save("discriminator_model.h5")
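The GANMonitor callback passed to fit above is defined elsewhere in the project. A minimal sketch of what it might look like, following the common Keras pattern; the self.model.generator attribute and the tanh-scaled outputs are assumptions, not the project's confirmed API:

import tensorflow as tf
from tensorflow import keras

class GANMonitor(keras.callbacks.Callback):
    """Hypothetical sketch: sample a few images at the end of each epoch."""

    def __init__(self, num_images=1, latent_dim=256):
        super().__init__()
        self.num_images = num_images
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        noise = tf.random.normal(shape=(self.num_images, self.latent_dim))
        # Assumes the compiled GAN model exposes its generator as `.generator`.
        generated = self.model.generator(noise, training=False)
        # Assumes a tanh output layer; rescale from [-1, 1] to [0, 255].
        images = ((generated * 127.5) + 127.5).numpy()
        for i in range(self.num_images):
            keras.utils.array_to_img(images[i]).save(f"generated_img_{epoch:03d}_{i}.png")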
Example 2
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)
    assert START_TOKEN == 0

    physical_devices = tf.config.experimental.list_physical_devices("GPU")
    if len(physical_devices) > 0:
        for dev in physical_devices:
            tf.config.experimental.set_memory_growth(dev, True)

    generator = Generator(VOCAB_SIZE, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    target_lstm = RNNLM(VOCAB_SIZE, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN) 
    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=VOCAB_SIZE, embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, dropout_keep_prob=dis_dropout_keep_prob,
                                  l2_reg_lambda=dis_l2_reg_lambda)
    
    gen_dataset = dataset_for_generator(positive_file, BATCH_SIZE)
    log = open('save/experiment-log.txt', 'w')
    # Pre-train the generator
    if not os.path.exists("save/generator_pretrained.h5"):
        print('Start pre-training...')
        log.write('pre-training...\n')
        generator.pretrain(gen_dataset, target_lstm, PRE_EPOCH_NUM, generated_num // BATCH_SIZE, eval_file)
        generator.save("save/generator_pretrained.h5")
    else:
        generator.load("save/generator_pretrained.h5")

    if not os.path.exists("discriminator_pretrained.h5"):
        print('Start pre-training discriminator...')
        # Train the discriminator for 3 epochs on freshly generated data, repeated 50 times
        for i in range(50):
            print("Dataset", i)
            generator.generate_samples(generated_num // BATCH_SIZE, negative_file)
            dis_dataset = dataset_for_discriminator(positive_file, negative_file, BATCH_SIZE)
            discriminator.train(dis_dataset, 3, (generated_num // BATCH_SIZE) * 2)
        discriminator.save("save/discriminator_pretrained.h5")
    else:
        discriminator.load("save/discriminator_pretrained.h5")

    rollout = ROLLOUT(generator, 0.8)

    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    
    for total_batch in range(TOTAL_BATCH):
        print("Generator", total_batch, 'of ', TOTAL_BATCH)
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate_one_batch()
            rewards = rollout.get_reward(samples, 16, discriminator)
            generator.train_step(samples, rewards)

        # Test
        if total_batch % 10 == 0 or total_batch == TOTAL_BATCH - 1:
            generator.generate_samples(generated_num // BATCH_SIZE, eval_file)
            likelihood_dataset = dataset_for_generator(eval_file, BATCH_SIZE)
            test_loss = target_lstm.target_loss(likelihood_dataset)
            buffer = f'epoch:\t{total_batch}\tnll:\t{test_loss}\n'
            print('total_batch:', total_batch, 'of:', TOTAL_BATCH, 'test_loss:', test_loss)
            generator.save(f"save/generator_{total_batch}.h5")
            discriminator.save(f"save/discriminator_{total_batch}.h5")
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        print("Discriminator", total_batch, 'of ', TOTAL_BATCH)
        # There will be 5 x 3 = 15 epochs in this loop
        for _ in range(5):
            generator.generate_samples(generated_num // BATCH_SIZE, negative_file)
            dis_dataset = dataset_for_discriminator(positive_file, negative_file, BATCH_SIZE)
            discriminator.train(dis_dataset, 3, (generated_num // BATCH_SIZE) * 2)
    generator.save(f"save/generator_{TOTAL_BATCH}.h5")
    discriminator.save(f"save/discriminator_{TOTAL_BATCH}.h5")

    log.close()
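The dataset_for_generator and dataset_for_discriminator helpers used above are not shown in this example. A minimal sketch under the assumption that each data file stores one space-separated sequence of integer token ids per line:

import numpy as np
import tensorflow as tf

def load_token_file(path):
    # Each line: SEQ_LENGTH space-separated integer token ids.
    with open(path) as f:
        return np.array([[int(tok) for tok in line.split()] for line in f],
                        dtype=np.int32)

def dataset_for_generator(positive_file, batch_size):
    tokens = load_token_file(positive_file)
    return (tf.data.Dataset.from_tensor_slices(tokens)
            .shuffle(len(tokens))
            .batch(batch_size, drop_remainder=True))

def dataset_for_discriminator(positive_file, negative_file, batch_size):
    pos, neg = load_token_file(positive_file), load_token_file(negative_file)
    x = np.concatenate([pos, neg])
    # Label real sequences 1 and generated sequences 0.
    y = np.concatenate([np.ones(len(pos)), np.zeros(len(neg))]).astype(np.int32)
    return (tf.data.Dataset.from_tensor_slices((x, y))
            .shuffle(len(x))
            .batch(batch_size, drop_remainder=True))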
Example 3
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)
    assert START_TOKEN == 0

    vocab_size = 5000

    physical_devices = tf.config.experimental.list_physical_devices("GPU")
    if len(physical_devices) > 0:
        for dev in physical_devices:
            tf.config.experimental.set_memory_growth(dev, True)

    # Initialize the generator
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    # Load the oracle's initial parameters from a pickle file
    target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
    target_lstm = TARGET_LSTM(BATCH_SIZE, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    # Initialize the discriminator
    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, dropout_keep_prob=dis_dropout_keep_prob,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    # First, use the oracle model to provide the positive examples,
    # which are sampled from the oracle data distribution
    if not os.path.exists(positive_file):
        target_lstm.generate_samples(generated_num // BATCH_SIZE, positive_file)
    gen_dataset = dataset_for_generator(positive_file, BATCH_SIZE)
    log = open('save/experiment-log.txt', 'w')


    # Pre-train the generator with the LSTM and save its weights
    if not os.path.exists("generator_pretrained.h5"):
        print('Start pre-training...')
        log.write('pre-training...\n')
        generator.pretrain(gen_dataset, target_lstm, PRE_EPOCH_NUM, generated_num // BATCH_SIZE, eval_file)
        generator.save("generator_pretrained.h5")
    else:
        generator.load("generator_pretrained.h5")

    # Weights from the discriminator's pre-training
    if not os.path.exists("discriminator_pretrained.h5"):
        print('Start pre-training discriminator...')
        # Train the discriminator for 3 epochs on freshly generated data, repeated 50 times
        for i in range(50):
            print("Dataset", i)

            # First, the generator produces fake samples
            generator.generate_samples(generated_num // BATCH_SIZE, negative_file)

            # Build a dataset mixing the fake and real samples
            dis_dataset = dataset_for_discriminator(positive_file, negative_file, BATCH_SIZE)

            # Train the discriminator on the mixed data
            discriminator.train(dis_dataset, 3, (generated_num // BATCH_SIZE) * 2)
        discriminator.save("discriminator_pretrained.h5")
    else:
        discriminator.load("discriminator_pretrained.h5")

    rollout = ROLLOUT(generator, 0.8)

    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')

    # Run the adversarial training loop
    # (TOTAL_BATCH = 200 rounds in this example)
    for total_batch in range(TOTAL_BATCH):
        print("Generator", total_batch)
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate_one_batch()
            rewards = rollout.get_reward(samples, 16, discriminator)
            generator.train_step(samples, rewards)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generator.generate_samples(generated_num // BATCH_SIZE, eval_file)
            likelihood_dataset = dataset_for_generator(eval_file, BATCH_SIZE)
            test_loss = target_lstm.target_loss(likelihood_dataset)
            buffer = f'epoch:\t{total_batch}\tnll:\t{test_loss}\n'
            print('total_batch:', total_batch, 'test_loss:', test_loss)
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        print("Discriminator", total_batch)
        for _ in range(5):
            generator.generate_samples(generated_num // BATCH_SIZE, negative_file)
            dis_dataset = dataset_for_discriminator(positive_file, negative_file, BATCH_SIZE)
            discriminator.train(dis_dataset, 3, (generated_num // BATCH_SIZE) * 2)
    generator.save("generator.h5")
    discriminator.save("discriminator.h5")
    
    generator.generate_samples(generated_num // BATCH_SIZE, output_file)

    log.close()
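The rollout.get_reward / generator.train_step pair above is SeqGAN's policy-gradient update: each sampled token is reinforced in proportion to the reward the Monte Carlo rollouts assign it. A minimal sketch of that loss; the function name and shapes are illustrative, not the project's actual API:

import tensorflow as tf

def reinforce_loss(logits, samples, rewards):
    """Policy-gradient loss: -E[log pi(y_t | y_<t) * reward_t].

    logits:  [batch, seq_len, vocab] unnormalized generator outputs
    samples: [batch, seq_len] integer token ids drawn from the generator
    rewards: [batch, seq_len] per-token rewards from the rollout
    """
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    # Pick the log-probability of each sampled token: log pi(y_t | y_<t).
    picked = tf.gather(log_probs, samples, batch_dims=2)
    return -tf.reduce_mean(picked * rewards)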
Example 4
import os

import tensorflow as tf

from train import train
from util.cifar import cifar_dataset
from util import summaries
from config import generated_dir
from config import size_factor
import _globals
# Generator and Discriminator come from the project's model definitions (not shown here).

if __name__ == "__main__":

    summary_writer = tf.summary.create_file_writer(os.path.join(generated_dir, 'summaries'))
    with summary_writer.as_default():

        g = Generator(size_factor)
        d = Discriminator(size_factor)
        try:
            print("Started training")
            epochs_per_cycle = 5
            for i in range(1000):
                timestamp = f'step_{_globals.step:08d}'
                g.save(timestamp)
                d.save(timestamp)
                g.save('latest')
                d.save('latest')

                g.sample(n=400, save_to=os.path.join(generated_dir, "output", timestamp))
                train(cifar_dataset, epochs_per_cycle, g=g, d=d)
                summaries.main()
        except KeyboardInterrupt:
            print("Training interrupted")
        finally:
            print("Saving the model")
            g.save("latest"), d.save("latest")