Example #1
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0
    vocab_size = 5000  # why not a constant?
    log = open('save/experiment-log.txt', 'w')

    #likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = construct_generator(vocab_size)
    target_lstm = construct_gold_generator(vocab_size)
    discriminator = construct_discriminator(vocab_size)

    sess = initialize_session()

    # First, use the oracle model to provide the positive examples,
    # which are sampled from the oracle data distribution
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                     positive_file)
    gen_data_loader.create_batches(positive_file)

    #  pre-train generator
    print('Start pre-training...')
    pretrain_generator(sess, generator, gen_data_loader, target_lstm, log)

    print('Start pre-training discriminator...')
    # Train 3 epochs on the generated data and repeat this 50 times
    pretrain_discriminator(sess, discriminator, dis_data_loader, generator)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        train_generator(sess, generator, target_lstm, rollout, discriminator,
                        total_batch, log)
        rollout.update_params()
        train_discriminator(sess, discriminator, dis_data_loader, generator)
    log.close()
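
Example #1 calls helper routines such as generate_samples that are not defined in the snippet. A minimal sketch of what this helper typically looks like in SeqGAN-style code is shown below; the generate(sess) method and the space-separated token-id file format are assumptions taken from the reference implementation, not from this example.

def generate_samples(sess, trainable_model, batch_size, generated_num, output_file):
    # Sketch (not part of the original example): sample `generated_num` sequences
    # in batches and write them to `output_file`, one space-separated sequence of
    # token ids per line.
    generated_samples = []
    for _ in range(int(generated_num / batch_size)):
        generated_samples.extend(trainable_model.generate(sess))
    with open(output_file, 'w') as fout:
        for sample in generated_samples:
            fout.write(' '.join(str(token) for token in sample) + '\n')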
Example #2
def create_model(sess, save_folder, FLAGS, embed_fn):
    # load vocab & embeddings
    with open(save_folder + "vocab.pkl", "rb") as handle:
        vocab = pickle.load(handle)
    with open(save_folder + "tsf_vocab_inv.pkl", "rb") as handle:
        tsf_vocab_inv = pickle.load(handle)
    with open(save_folder + "init_embed.pkl", "rb") as handle:
        init_embed = pickle.load(handle)
    with open(save_folder + "tsf_init_embed.pkl", "rb") as handle:
        tsf_init_embed = pickle.load(handle)
    vocab_size = len(vocab)
    tsf_vocab_size = len(tsf_vocab_inv)
    print("Vocab size: {}, transfer vocab size: {}".format(
        vocab_size, tsf_vocab_size))

    # generator
    config_list = [(k, FLAGS[k].value) for k in FLAGS]
    generator_config = OrderedDict(
        sorted(config_list) +
        [("encoder_vocab_size",
          vocab_size), ("decoder_vocab_size", tsf_vocab_size)])
    #print("Generator config: {}, cell_type: {}".format(generator_config, "gru"))
    generator = Generator(generator_config, init_embed, tsf_init_embed)

    # language model
    lm_config_list = [(k, FLAGS[k].value) for k in FLAGS if k.startswith("lm_")
                      ] + [("batch_size", FLAGS.batch_size)]
    lm_config = OrderedDict(
        sorted(lm_config_list) + [("lm_vocab_size", vocab_size)])
    rnnlm = RNNLM(lm_config, init_embed)

    # style discriminator
    style_discriminator = StyleDiscriminator(FLAGS.style_num_classes, FLAGS.embedding_dim, \
                                             init_embed, FLAGS.style_hidden_size, \
                                             FLAGS.style_attention_size, FLAGS.max_sent_len, \
                                             FLAGS.style_keep_prob)
    #embedding_size, init_embed, hidden_size, \
    #                 attention_size, max_sent_len, keep_prob):
    #siamese discriminator
    siamese_discrim = SiameseDiscriminator(FLAGS.embedding_dim, \
                                             init_embed, FLAGS.style_hidden_size, \
                                             FLAGS.style_attention_size, FLAGS.max_sent_len, \
                                             FLAGS.style_keep_prob)
    # semantic discriminator
    semantic_discriminator = SemanticDiscriminator(embed_fn)

    # rollout
    rollout = ROLLOUT(vocab, tsf_vocab_inv)

    return generator, rnnlm, style_discriminator, siamese_discrim, semantic_discriminator, rollout, vocab, tsf_vocab_inv
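
The generator and language-model configs above are built by sorting (flag, value) pairs and appending derived entries such as the vocabulary sizes. A small self-contained illustration of that pattern, using made-up flag values rather than the real FLAGS object, is:

from collections import OrderedDict

# Illustration only: stand-in flag values, not taken from the example above.
flag_values = {"batch_size": 64, "lm_hidden_size": 256, "style_num_classes": 2}
config_list = sorted(flag_values.items())
generator_config = OrderedDict(config_list + [("encoder_vocab_size", 10000),
                                              ("decoder_vocab_size", 8000)])
print(generator_config["encoder_vocab_size"])  # 10000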
Example #3
def main():
    clock = Clock()
    clock.start()
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    parser = argparse.ArgumentParser(description='conditional SeqGAN')
    parser.add_argument('--conditional',
                        '-c',
                        type=int,
                        default=0,
                        help='If you make SeqGAN conditional, set `-c` 1.')
    args = parser.parse_args()
    cond = args.conditional

    vocab = Vocab()
    vocab.construct(parsed_haiku_file)
    vocab.word2id(parsed_haiku_file, positive_file)
    UNK = vocab.dic.token2id[u'<UNK>']
    COMMA = vocab.dic.token2id[u',']

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH, COND_LENGTH, UNK)
    # likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH, COND_LENGTH, UNK) # For testing
    vocab_size = len(vocab.dic.token2id)
    with open(output_token2id, 'w') as f:
        pickle.dump(vocab.dic.token2id, f)
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH, UNK)

    generator = Generator(vocab_size,
                          BATCH_SIZE,
                          EMB_DIM,
                          HIDDEN_DIM,
                          SEQ_LENGTH,
                          COND_LENGTH,
                          START_TOKEN,
                          is_cond=cond)
    # target_params = cPickle.load(open('save/target_params.pkl'))
    # target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  cond_length=COND_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  batch_size=BATCH_SIZE,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda,
                                  is_cond=cond)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    # generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    if cond:
        vocab.word2id(parsed_kigo_file, positive_condition_file)
        vocab.load_cond(positive_condition_file, COND_LENGTH, UNK)
        gen_data_loader.create_cond_batches(positive_condition_file)

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_GEN_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader, cond=cond)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file, cond, vocab)
            # likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            # print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            # buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            # log.write(buffer)
    clock.check_HMS()

    print 'Start pre-training discriminator...'
    # Train 3 epochs on the generated data and repeat this 50 times
    for _ in range(PRE_EPOCH_DIS_NUM):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file, cond, vocab)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)
    clock.check_HMS()

    rollout = ROLLOUT(generator, 0.8, SEQ_LENGTH)

    print '#########################################################################'
    print 'Start Adversarial Training...'
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            if cond:
                cond_batch = vocab.choice_cond(BATCH_SIZE)
                samples = generator.generate(sess, cond=cond_batch)
                rewards = rollout.get_reward(sess,
                                             samples,
                                             16,
                                             discriminator,
                                             cond=cond_batch)
            else:
                samples = generator.generate(sess)
                rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            if cond:
                feed[generator.cond] = cond_batch
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file, cond, vocab)
            # likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            # buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            # print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            # log.write(buffer)
            if total_batch % 20 == 0 or total_batch == TOTAL_BATCH - 1:
                if cond:
                    vocab.id2word(
                        eval_file,
                        generated_haiku_with_kigo_file.format(total_batch))
                else:
                    vocab.id2word(eval_file,
                                  generated_haiku_file.format(total_batch))

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file, cond, vocab)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
    clock.check_HMS()
    saver = tf.train.Saver()
    saver.save(sess, output_generator)
    log.close()
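
The adversarial update in this example relies on rollout.get_reward, whose body is not shown. In the SeqGAN reference code it performs Monte Carlo rollouts from every prefix and scores the completions with the discriminator; the sketch below follows that idea, but the attribute names (gen_x, given_num, ypred_for_auc) are assumptions carried over from that reference code rather than from this example.

import numpy as np

def monte_carlo_rewards(sess, rollout, discriminator, samples, rollout_num, seq_length):
    # Sketch of ROLLOUT.get_reward: for each prefix length, complete the sequence
    # rollout_num times with the roll-out policy and use the discriminator's
    # probability of "real" as the reward for that position.
    rewards = np.zeros((samples.shape[0], seq_length))
    for _ in range(rollout_num):
        for given_num in range(1, seq_length):
            completed = sess.run(rollout.gen_x,
                                 {rollout.x: samples, rollout.given_num: given_num})
            scores = sess.run(discriminator.ypred_for_auc,
                              {discriminator.input_x: completed,
                               discriminator.dropout_keep_prob: 1.0})
            rewards[:, given_num - 1] += scores[:, 1]
        # the last position is scored on the finished samples themselves
        scores = sess.run(discriminator.ypred_for_auc,
                          {discriminator.input_x: samples,
                           discriminator.dropout_keep_prob: 1.0})
        rewards[:, seq_length - 1] += scores[:, 1]
    return rewards / float(rollout_num)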
Example #4
    generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file, word_embedding_matrix, random_type)
    dis_data_loader.load_train_data(positive_file, negative_file)
    for pd in range(3): # 3
        print("dis_pretrain: ", pd)
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch): # set this to 1 for a quick run
            seq_batch, condition_batch, label_batch = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: seq_batch,
                discriminator.input_cond: condition_batch,
                discriminator.input_y: label_batch,
                discriminator.dropout_keep_prob: dis_dropout_keep_prob
            }
            _ = sess.run(discriminator.train_op, feed)

rollout = ROLLOUT(generator, 0.8, word_embedding_matrix)

print('#########################################################################')
print('Start Adversarial Training...')
gen_sample.write('adversarial training...\n')
for total_batch in range(TOTAL_BATCH):
    # Train the generator for one step
    for it in range(1):
        random_type = np.random.randint(0, TYPE_SIZE, BATCH_SIZE)
        samples = generator.generate(sess, word_embedding_matrix, random_type)
        rewards = rollout.get_reward(sess, samples, 16, discriminator, random_type)
        feed = {generator.x: samples, generator.rewards: rewards, generator.type_index: random_type,
                generator.word_embedding_matrix: word_embedding_matrix}
        _ = sess.run(generator.g_updates, feed_dict=feed)

    # Test
Example #5
def main(source_file, wordVocab, vocab_size):
    tf.reset_default_graph()
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing

    # todo:  print ("starting generating positive samples...")
    generated_num = gen_data_loader.transform_positive_file_2(
        train_dir + source_file, train_dir + positive_file, wordVocab,
        SEQ_LENGTH)
    print("generated_num: ", generated_num)
    if generated_num < 100: return
    gen_data_loader.create_batches(train_dir + positive_file)

    with tf.variable_scope("Train", reuse=None):
        generator = Generator(wordVocab,
                              vocab_size,
                              BATCH_SIZE,
                              EMB_DIM,
                              HIDDEN_DIM,
                              SEQ_LENGTH,
                              START_TOKEN,
                              learning_rate=g_lrn)

        discriminator = Discriminator(word_vocab=wordVocab,
                                      sequence_length=SEQ_LENGTH,
                                      num_classes=2,
                                      embedding_size=dis_embedding_dim,
                                      filter_sizes=dis_filter_sizes,
                                      num_filters=dis_num_filters,
                                      l2_reg_lambda=dis_l2_reg_lambda,
                                      learning_rate=d_lrn)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # todo:  1.##############pre-train generator##############
    print 'Start pre-training generator with MLE...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM_generator):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            buffer = 'epoch:\t' + str(epoch) + '\tloss:\t' + str(loss)
            print(buffer)
            sys.stdout.flush()
            log.write(buffer)
            # generate_samples(sess,
            #                  generator,
            #                  BATCH_SIZE,
            #                  generated_num,
            #                  eval_file)
            # likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            # print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            # buffer = 'epoch:\t' + str(epoch) + '\tnllscore:\t' + str(test_loss) + '\n'
            # log.write(buffer)

    # todo:  2.##############pre-train discriminator##############
    print 'Start pre-training discriminator...'
    for _ in range(PRE_EPOCH_NUM_discriminator):
        ## Since we sample from the probability distribution, the generated fake data is different every time
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):  ## train the discriminator on each batch of fake data
            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    with tf.variable_scope("Train", reuse=None):
        g_beta = ROLLOUT(generator, 0.8)  ## this is the roll-out policy g_beta

    # todo:  3.############## Adversarial Training ##############
    print '#########################################################################'
    print 'Start Adversarial Training...'
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # todo: Train the generator for one batch samples
        for it in range(1):
            samples = generator.generate(sess)
            rewards = g_beta.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss],
                                 feed_dict=feed)

        # Test
        if total_batch % 10 == 0 or total_batch == TOTAL_BATCH - 1:
            buffer = 'epoch:\t' + str(total_batch) + '\tg_loss:\t' + str(
                g_loss)
            print(buffer)
            sys.stdout.flush()
            log.write(buffer)

        g_beta.update_params()

        # todo: Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

        if total_batch % 10 == 0 or total_batch == TOTAL_BATCH - 1:
            out_file = out_negative_file + str(total_batch) + ".txt"
            transform_file(negative_file, wordVocab, out_file)

    generate_samples(sess, generator, BATCH_SIZE, need_generated_samples,
                     negative_file)
    transform_file(negative_file, wordVocab, source_file + ".GEN")
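
pre_train_epoch (supervised MLE pre-training of the generator) is used here and in several other examples without being shown. A minimal sketch along the lines of the SeqGAN reference code follows; the pretrain_step method name is an assumption from that reference.

import numpy as np

def pre_train_epoch(sess, trainable_model, data_loader):
    # One epoch of supervised (MLE) pre-training; returns the mean batch loss.
    supervised_g_losses = []
    data_loader.reset_pointer()
    for _ in range(data_loader.num_batch):
        batch = data_loader.next_batch()
        _, g_loss = trainable_model.pretrain_step(sess, batch)
        supervised_g_losses.append(g_loss)
    return np.mean(supervised_g_losses)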
Example #6
def seqgan(pos_file_a, pos_file_b):
    print('Init Variable ###########################################')
    # random.seed(SEED)
    # np.random.seed(SEED)
    assert START_TOKEN == 0

    positive_file_a, negative_file_a, output_music_gan_a, output_music_mle_a = init_var(
        pos_file_a)
    gen_data_loader_a, dis_data_loader_a = init_data_loader(positive_file_a)
    generator_a = gen('a')
    discriminator_a = dis('a')

    positive_file_b, negative_file_b, output_music_gan_b, output_music_mle_b = init_var(
        pos_file_b)
    gen_data_loader_b, dis_data_loader_b = init_data_loader(positive_file_b)
    generator_b = gen('b')
    discriminator_b = dis('b')

    negative_file_f = 'tmp_f'
    dis_data_loader_f = F_Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    gen_data_loader_f = F_Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    gen_data_loader_f.create_batches(positive_file_a, positive_file_b)
    generator_f = gen('f')
    discriminator_f = dis('f', num_class=3)

    dis_data_loader_ab = AB_Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)

    print('Init TensorFlow ###########################################')
    # init TensorFlow Session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # for tensorboard, debug
    tb_write = tf.summary.FileWriter(logdir)
    tb_write.add_graph(sess.graph)
    mio = MIDI_IO()

    print(
        '#########################################################################'
    )
    with tf.name_scope('pretrain-G-A'):
        print('Start pre-training generator A')
        pre_train_epoch(sess, generator_a, gen_data_loader_a, tb_write,
                        PRE_GEN_EPOCH)

        # with tf.name_scope('pretrain-D-A'):
        print('Start pre-training discriminator A')
        train_d(sess, dis_data_loader_a, positive_file_a, negative_file_a,
                generator_a, discriminator_a, tb_write, PRE_DIS_EPOCH * 3)

        # with tf.name_scope('rollout-A'):
        rollout_a = ROLLOUT(generator_a, 0.8, SEQ_LENGTH)

    # sample sequence after MLE
    generate_samples(sess, generator_a, BATCH_SIZE, gen_len,
                     output_music_mle_a)

    print(
        '-------------------------------------------------------------------------'
    )
    print('Start Adversarial Training A')
    with tf.name_scope('GANtrain-A'):
        for total_batch in range(GAN_OUTER_EPOCH):
            print('Adversarial Training Progress ' + str(total_batch))

            # Train the generator
            gan_g(sess, generator_a, discriminator_a, rollout_a, tb_write,
                  GAN_G_EPOCH)

            # Update roll-out parameters
            rollout_a.update_params()

            # Train the discriminator
            train_d(sess, dis_data_loader_a, positive_file_a, negative_file_a,
                    generator_a, discriminator_a, tb_write, GAN_D_EPOCH)

    generate_samples(sess, generator_a, BATCH_SIZE, gen_len,
                     output_music_gan_a)

    print(
        '#########################################################################'
    )
    print(
        '#########################################################################'
    )
    with tf.name_scope('pretrain-G-B'):
        print('Start pre-training generator B')
        pre_train_epoch(sess, generator_b, gen_data_loader_b, tb_write,
                        PRE_GEN_EPOCH)

        # with tf.name_scope('pretrain-D-B'):
        print('Start pre-training discriminator B')
        train_d(sess, dis_data_loader_b, positive_file_b, negative_file_b,
                generator_b, discriminator_b, tb_write, PRE_DIS_EPOCH)

        # with tf.name_scope('rollout-B'):
        rollout_b = ROLLOUT(generator_b, 0.8, SEQ_LENGTH)

    # sample sequence after MLE
    generate_samples(sess, generator_b, BATCH_SIZE, gen_len,
                     output_music_mle_b)

    print(
        '-------------------------------------------------------------------------'
    )
    print('Start Adversarial Training B')
    with tf.name_scope('GANtrain-B'):
        for total_batch in range(GAN_OUTER_EPOCH):
            print('Adversarial Training Progress ' + str(total_batch))

            # Train the generator
            gan_g(sess, generator_b, discriminator_b, rollout_b, tb_write,
                  GAN_G_EPOCH)

            # Update roll-out parameters
            rollout_b.update_params()

            # Train the discriminator
            train_d(sess, dis_data_loader_b, positive_file_b, negative_file_b,
                    generator_b, discriminator_b, tb_write, GAN_D_EPOCH)

    # generate_samples(sess, generator_b, BATCH_SIZE, gen_len, output_music_gan_b)

    print(
        '#########################################################################'
    )
    print(
        '#########################################################################'
    )
    with tf.name_scope('pretrain-G-F'):
        print('Start pre-training generator F')
        pre_train_epoch(sess, generator_f, gen_data_loader_f, tb_write,
                        PRE_GEN_EPOCH)

        # with tf.name_scope('pretrain-D-F'):
        print('Start pre-training discriminator F')
        f_train_d(sess, dis_data_loader_f, positive_file_a, positive_file_b,
                  negative_file_f, generator_f, discriminator_f, tb_write,
                  PRE_DIS_EPOCH)

        # with tf.name_scope('rollout-B'):
        rollout_f = ROLLOUT(generator_f, 0.8, SEQ_LENGTH)

    # sample sequence after MLE
    generate_samples(sess, generator_f, BATCH_SIZE, gen_len, 'pretrain-f')
    mio.trans_generated_to_midi('pretrain-f')

    # print '-------------------------------------------------------------------------'
    # print 'Start Adversarial Training F'
    # with tf.name_scope('GANtrain-F'):
    #     for total_batch in range(GAN_OUTER_EPOCH):
    #         print 'Adversarial Training Progress', total_batch
    #
    #         # Train the generator
    #         gan_g(sess, generator_f, discriminator_f, rollout_f, tb_write, GAN_G_EPOCH)
    #
    #         # Update roll-out parameters
    #         rollout_f.update_params()
    #
    #         # Train the discriminator
    #         train_d(sess, dis_data_loader_f, positive_file_f, negative_file_f, generator_f, discriminator_f,
    #                 tb_write, GAN_D_EPOCH)

    for fusion_total_batch in range(FUSION_EPOCH):
        print(
            '#########################################################################'
        )
        print('Start Fusion GAN ' + str(fusion_total_batch))
        print(
            '-------------------------------------------------------------------------'
        )
        print('Start Fusion GAN Training F')
        with tf.name_scope('Fusion-A-B'):
            for total_batch in range(FUSION_F_EPOCH):
                fusion_g(sess, generator_f, discriminator_a, discriminator_b,
                         discriminator_f, rollout_f, tb_write, FUSION_G_EPOCH)
                rollout_f.update_params()
                f_fusion_d(sess, dis_data_loader_f, positive_file_a,
                           positive_file_b, negative_file_f, generator_f,
                           discriminator_f, tb_write, FUSION_D_EPOCH)

        generate_samples(sess, generator_f, BATCH_SIZE, gen_len,
                         'fusion_' + str(fusion_total_batch))
        mio.trans_generated_to_midi('fusion_' + str(fusion_total_batch))
        # at last iteration, A and B do not need training
        print('++++++')
        print('CHK PNT ' + str(fusion_total_batch) + " " + str(FUSION_EPOCH))
        print('++++++')
        if fusion_total_batch == FUSION_EPOCH - 1:
            break
        print(
            '-------------------------------------------------------------------------'
        )
        print('Start Fusion GAN Training A')
        with tf.name_scope('GANtrain-A'):
            for total_batch in range(FUSION_AB_EPOCH):
                print('Adversarial Training Progress ' + str(total_batch))

                # Train the generator
                fusion_g(sess, generator_a, discriminator_a, discriminator_b,
                         discriminator_f, rollout_a, tb_write, FUSION_G_EPOCH)

                # Update roll-out parameters
                rollout_a.update_params()

                # Train the discriminator
                ab_fusion_d(sess, dis_data_loader_ab, positive_file_a,
                            positive_file_b, negative_file_a, generator_a,
                            discriminator_a, generator_f, tb_write,
                            GAN_D_EPOCH)

        print(
            '-------------------------------------------------------------------------'
        )
        print('Start Fusion GAN Training B')
        with tf.name_scope('GANtrain-B'):
            for total_batch in range(FUSION_AB_EPOCH):
                print('Adversarial Training Progress ' + str(total_batch))

                # Train the generator
                fusion_g(sess, generator_b, discriminator_a, discriminator_b,
                         discriminator_f, rollout_b, tb_write, FUSION_G_EPOCH)

                # Update roll-out parameters
                rollout_b.update_params()

                # Train the discriminator
                ab_fusion_d(sess, dis_data_loader_ab, positive_file_b,
                            positive_file_a, negative_file_b, generator_b,
                            discriminator_b, generator_f, tb_write,
                            GAN_D_EPOCH)

    output_music_fusion = 'fusion_gan{}'.format(suffix)
    generate_samples(sess, generator_f, BATCH_SIZE, gen_len,
                     output_music_fusion)

    mio.trans_generated_to_midi(output_music_fusion)
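
Example #6 delegates discriminator training to a train_d wrapper that is not shown. A plausible sketch, mirroring the explicit discriminator loops in the other examples, is given below; BATCH_SIZE, gen_len and dis_dropout_keep_prob are the module-level constants used above, and dropping the TensorBoard logging is a simplification.

def train_d(sess, dis_data_loader, positive_file, negative_file,
            generator, discriminator, tb_write, epochs):
    # Sketch only: generate fresh negatives, rebuild the training set,
    # and run one pass over it, `epochs` times.
    for _ in range(epochs):
        generate_samples(sess, generator, BATCH_SIZE, gen_len, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        dis_data_loader.reset_pointer()
        for _ in range(dis_data_loader.num_batch):
            x_batch, y_batch = dis_data_loader.next_batch()
            feed = {discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob}
            sess.run(discriminator.train_op, feed)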
Example #7
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    #
    # Declare data loader
    # ----------------------------------------------------------------------------
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
    vocab_size = 5000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    # ----------------------------------------------------------------------------


    #
    # Declare Generator & Discriminator
    # ----------------------------------------------------------------------------
    # declare: generator
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    target_params = cPickle.load(open('save/target_params.pkl'))
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    # declare: discriminator
    discriminator = Discriminator(sequence_length=20, num_classes=2,
                                   vocab_size=vocab_size, embedding_size=dis_embedding_dim,
                                   filter_sizes=dis_filter_sizes, num_filters=dis_num_filters,
                                   l2_reg_lambda=dis_l2_reg_lambda)
    # ----------------------------------------------------------------------------

    #
    # Set the session <sess>
    # ----------------------------------------------------------------------------
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # ----------------------------------------------------------------------------

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    # generate samples by using <target_lstm> and write the samples to file <positive_file>
    #generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')


    #
    # Pre-train <generator> by using <gen_data_loader>,
    # and then compute the <test_loss> of <target_lstm> and <likelihood_data_loader>
    # ----------------------------------------------------------------------------
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            # generate samples by using <generator> and write the samples to file <eval_file>
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)

            # load samples from file <eval_file>
            likelihood_data_loader.create_batches(eval_file)

            # compute <test_loss> of <target_lstm>, with input <likelihood_data_loader>
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)

            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)
    # ----------------------------------------------------------------------------


    #
    # Pre-train <discriminator> by using <generator>
    # ----------------------------------------------------------------------------
    print('Start pre-training discriminator...')
    # Generate data and train 3 epochs on it; repeat this 50 times
    for _ in range(50):
        # generate samples by using <generator> and write the samples to file <negative_file>
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

        # load samples from file <negative_file>
        dis_data_loader.load_train_data(positive_file, negative_file)

        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob}
                _ = sess.run(discriminator.train_op, feed_dict=feed)
    # ----------------------------------------------------------------------------

    rollout = ROLLOUT(generator, 0.8)

    #
    # Start seqGAN, train <discriminator> and <generator>
    # ----------------------------------------------------------------------------
    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):

        # ----- Train the generator for one step -----------------
        for it in range(G_STEPS):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, ROLLOUT_NUM, discriminator, SEQ_LENGTH)
            feed = {generator.x: samples,
                    generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
        # --------------------------------------------------------

        # Update roll-out parameters
        rollout.update_params()

        # ----- Train the discriminator -------------------------
        for _ in range(D_STEPS):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {discriminator.input_x: x_batch,
                            discriminator.input_y: y_batch,
                            discriminator.dropout_keep_prob: dis_dropout_keep_prob}
                    _ = sess.run(discriminator.train_op, feed_dict=feed)
        # --------------------------------------------------------
    # ----------------------------------------------------------------------------

    log.close()
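
The NLL evaluation used in Example #7 (and commented out in several other examples) relies on target_loss, which computes the oracle's negative log-likelihood of generated samples. A sketch following the SeqGAN reference implementation is below; the pretrain_loss and x attributes of the oracle are assumptions from that reference code.

import numpy as np

def target_loss(sess, target_lstm, data_loader):
    # Average per-batch NLL that the oracle (target_lstm) assigns to the
    # sequences held by data_loader.
    nll = []
    data_loader.reset_pointer()
    for _ in range(data_loader.num_batch):
        batch = data_loader.next_batch()
        g_loss = sess.run(target_lstm.pretrain_loss, {target_lstm.x: batch})
        nll.append(g_loss)
    return np.mean(nll)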
Example #8
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    # assert START_TOKEN == 0

    vocab_size = NUM_EMB
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM,
                          HIDDEN_DIM, MAX_LENGTH, START_TOKEN)
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE,
                              EMB_DIM, HIDDEN_DIM, MAX_LENGTH, 0)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(
            sequence_length=MAX_LENGTH,
            num_classes=2,
            vocab_size=vocab_size,
            embedding_size=dis_embedding_dim,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [param for param in tf.trainable_variables()
                  if 'discriminator' in param.name]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(
        cnn.loss, cnn_params, aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(
        dis_grads_and_vars, global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    def train_discriminator():
        if D_WEIGHT == 0:
            return 0, 0

        negative_samples = generate_samples(
            sess, generator, BATCH_SIZE, POSITIVE_NUM)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_samples, negative_samples)
        dis_batches = dis_data_loader.batch_iter(
            zip(dis_x_train, dis_y_train), dis_batch_size, dis_num_epochs
        )

        for batch in dis_batches:
            x_batch, y_batch = zip(*batch)
            feed = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: dis_dropout_keep_prob
            }
            _, step, loss, accuracy = sess.run(
                [dis_train_op, dis_global_step, cnn.loss, cnn.accuracy], feed)
        print('\tD loss  :   {}'.format(loss))
        print('\tAccuracy: {}'.format(accuracy))
        return loss, accuracy

    # Pre-training is checkpointed and only executes if we don't find a checkpoint
    saver = tf.train.Saver()
    ckpt_dir = 'checkpoints/{}_pretrain'.format(PREFIX)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    ckpt_file = os.path.join(ckpt_dir, 'pretrain_ckpt')
    if os.path.isfile(ckpt_file + '.meta') and params["LOAD_PRETRAIN"]:
        saver.restore(sess, ckpt_file)
        print('Pretrain loaded from previous checkpoint {}'.format(ckpt_file))
    else:
        sess.run(tf.global_variables_initializer())
        pretrain(sess, generator, target_lstm, train_discriminator)
        path = saver.save(sess, ckpt_file)
        print('Pretrain finished and saved at {}'.format(path))

    # create reward function
    batch_reward = make_reward(train_samples)

    rollout = ROLLOUT(generator, 0.8)

    print('#########################################################################')
    print('Start Reinforcement Training Generator...')
    results_rows = []
    for nbatch in range(TOTAL_BATCH):
        results = OrderedDict({'exp_name': PREFIX})
        if nbatch % 1 == 0 or nbatch == TOTAL_BATCH - 1:
            print('* Making samples')
            if nbatch % 10 == 0:
                gen_samples = generate_samples(
                    sess, generator, BATCH_SIZE, BIG_SAMPLE_NUM)
            else:
                gen_samples = generate_samples(
                    sess, generator, BATCH_SIZE, SAMPLE_NUM)
            likelihood_data_loader.create_batches(gen_samples)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('batch_num: {}'.format(nbatch))
            print('test_loss: {}'.format(test_loss))
            results['Batch'] = nbatch
            results['test_loss'] = test_loss

            if test_loss < best_score:
                best_score = test_loss
                print('best score: %f' % test_loss)

            # results
            mm.compute_results(gen_samples, train_samples, ord_dict, results)

        print('#########################################################################')
        print('-> Training generator with RL.')
        print('G Epoch {}'.format(nbatch))

        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(
                sess, samples, 16, cnn, batch_reward, D_WEIGHT)
            print('Rewards be like...')
            print(rewards)
            nll = generator.generator_step(sess, samples, rewards)

            print('neg-loglike: {}'.format(nll))
            results['neg-loglike'] = nll
        rollout.update_params()

        # generate for discriminator
        print('-> Training Discriminator')
        for i in range(D):
            print('D_Epoch {}'.format(i))
            d_loss, accuracy = train_discriminator()
            results['D_loss_{}'.format(i)] = d_loss
            results['Accuracy_{}'.format(i)] = accuracy
        print('results')
        results_rows.append(results)
        if nbatch % params["EPOCH_SAVES"] == 0:
            save_results(sess, PREFIX, PREFIX + '_model', results_rows)

    # write results
    save_results(sess, PREFIX, PREFIX + '_model', results_rows)

    print('\n:*** FINISHED ***')
    return
Example #9
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)
    assert START_TOKEN == 0

    physical_devices = tf.config.experimental.list_physical_devices("GPU")
    if len(physical_devices) > 0:
        for dev in physical_devices:
            tf.config.experimental.set_memory_growth(dev, True)

    generator = Generator(VOCAB_SIZE, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    target_lstm = RNNLM(VOCAB_SIZE, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN) 
    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=VOCAB_SIZE, embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, dropout_keep_prob=dis_dropout_keep_prob,
                                  l2_reg_lambda=dis_l2_reg_lambda)
    
    gen_dataset = dataset_for_generator(positive_file, BATCH_SIZE)
    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    if not os.path.exists("save/generator_pretrained.h5"):
        print('Start pre-training...')
        log.write('pre-training...\n')
        generator.pretrain(gen_dataset, target_lstm, PRE_EPOCH_NUM, generated_num // BATCH_SIZE, eval_file)
        generator.save("save/generator_pretrained.h5")
    else:
        generator.load("save/generator_pretrained.h5")

    if not os.path.exists("discriminator_pretrained.h5"):
        print('Start pre-training discriminator...')
        # Train 3 epochs on the generated data and repeat this 50 times
        for _ in range(50):
            print("Dataset", _)
            generator.generate_samples(generated_num // BATCH_SIZE, negative_file)
            dis_dataset = dataset_for_discriminator(positive_file, negative_file, BATCH_SIZE)
            discriminator.train(dis_dataset, 3, (generated_num // BATCH_SIZE) * 2)
        discriminator.save("save/discriminator_pretrained.h5")
    else:
        discriminator.load("save/discriminator_pretrained.h5")

    rollout = ROLLOUT(generator, 0.8)

    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    
    for total_batch in range(TOTAL_BATCH):
        print("Generator", total_batch, 'of ', TOTAL_BATCH)
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate_one_batch()
            rewards = rollout.get_reward(samples, 16, discriminator)
            generator.train_step(samples, rewards)

        # Test
        if total_batch % 10 == 0 or total_batch == TOTAL_BATCH - 1:
            generator.generate_samples(generated_num // BATCH_SIZE, eval_file)
            likelihood_dataset = dataset_for_generator(eval_file, BATCH_SIZE)
            test_loss = target_lstm.target_loss(likelihood_dataset)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print('total_batch: ', total_batch, 'of: ', TOTAL_BATCH, 'test_loss: ', test_loss)
            generator.save(f"save/generator_{total_batch}.h5")
            discriminator.save(f"save/discriminator_{total_batch}.h5")
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        print("Discriminator", total_batch, 'of ', TOTAL_BATCH)
        # There will be 5 x 3 = 15 epochs in this loop
        for _ in range(5):
            generator.generate_samples(generated_num // BATCH_SIZE, negative_file)
            dis_dataset = dataset_for_discriminator(positive_file, negative_file, BATCH_SIZE)
            discriminator.train(dis_dataset, 3, (generated_num // BATCH_SIZE) * 2)
    generator.save(f"save/generator_{TOTAL_BATCH}.h5")
    discriminator.save(f"save/discriminator_{TOTAL_BATCH}.h5")

    log.close()
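
Example #9 targets TensorFlow 2 and feeds the models through dataset_for_generator and dataset_for_discriminator helpers that are not shown. One way dataset_for_generator could be implemented with tf.data is sketched below, assuming each line of the file holds a fixed-length, space-separated sequence of token ids; both the helper's behaviour and the file format are assumptions.

import tensorflow as tf

def dataset_for_generator(file_path, batch_size):
    # Hypothetical sketch: parse each line into a 1-D int32 tensor of token ids
    # and group the equal-length sequences into fixed-size batches.
    def parse_line(line):
        return tf.strings.to_number(tf.strings.split(line), out_type=tf.int32)
    return (tf.data.TextLineDataset(file_path)
            .map(parse_line)
            .batch(batch_size, drop_remainder=True))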
Example #10
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    if os.path.exists(DICO_PKL):
        with open(DICO_PKL, 'rb') as f:
            word_to_id, id_to_word = pickle.load(f)
    else:
        word_to_id, id_to_word = create_dico(DICO)
        with open(DICO_PKL, 'wb') as f:
            pickle.dump([word_to_id, id_to_word], f)

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, word_to_id)
    dis_data_loader = Dis_Data_loader(BATCH_SIZE, word_to_id)
    vocab_size = len(word_to_id)
    assert START_TOKEN == word_to_id['sos']

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    discriminator = BLEUCNN(SEQ_LENGTH, 2, EMB_DIM, generator)
    mobilenet = MobileNet(BATCH_SIZE)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    mobilenet.load_pretrained_weights(sess)
    sess.run(tf.global_variables_initializer())

    log = open('experiment-log.txt', 'w', encoding='utf-8')
    #  pre-train generator and discriminator
    log.write('pre-training...\n')
    print('Start pre-training discriminator...')
    datas = create_data(DICO, word_to_id)
    gen_data_loader.create_batches(CORPUS, IMAGE)
    samples = []
    for it in range(gen_data_loader.num_batch):
        inp_batch, image_batch = gen_data_loader.next_batch()
        feed_dict = {mobilenet.X: image_batch, mobilenet.is_training: False}
        hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
        samples.extend(generator.generate(sess, hidden_batch).tolist())
    dis_data_loader.create_batches(random.sample(datas, 3000), samples)
    for _ in range(PRE_EPOCH_NUM):
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch):
            x_batch, labels = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.labels: labels,
                discriminator.dropout_keep_prob: 0.75
            }
            _ = sess.run(discriminator.train_op, feed)

    print('Start pre-training generator...')
    for epoch in range(PRE_EPOCH_NUM):
        supervised_g_losses = []
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            inp_batch, image_batch = gen_data_loader.next_batch()
            feed_dict = {
                mobilenet.X: image_batch,
                mobilenet.is_training: False
            }
            hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
            _, g_loss = generator.pretrain_step(sess, inp_batch, hidden_batch)
            supervised_g_losses.append(g_loss)
        loss = np.mean(supervised_g_losses)
        if epoch % 5 == 0:
            print('pre-train epoch ', epoch, 'train_loss ', loss)
            buffer = 'epoch:\t' + str(epoch) + '\ttrain_loss:\t' + str(
                loss) + '\n'
            log.write(buffer)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start REINFORCE Training...')
    log.write('REINFORCE training...\n')
    for total_batch in range(RL_EPOCH_NUM):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            ra = random.randint(0, 1)
            inp_batch, image_batch = gen_data_loader.next_batch(shuffle=ra)
            feed_dict = {
                mobilenet.X: image_batch,
                mobilenet.is_training: False
            }
            hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
            samples = generator.generate(sess, hidden_batch)
            rewards = rollout.get_reward(sess, samples, hidden_batch, 16,
                                         discriminator)
            feed = {
                generator.x: inp_batch,
                generator.rewards: rewards,
                generator.hiddens: hidden_batch
            }
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 5 == 0 or total_batch == RL_EPOCH_NUM - 1:
            mean_rewards = []
            gen_data_loader.reset_pointer()
            for it in range(gen_data_loader.num_batch):
                inp_batch, image_batch = gen_data_loader.next_batch()
                feed_dict = {
                    mobilenet.X: image_batch,
                    mobilenet.is_training: False
                }
                hidden_batch = sess.run(mobilenet.y_output,
                                        feed_dict=feed_dict)
                samples = generator.generate(sess, hidden_batch)
                rewards = rollout.get_reward(sess, samples, hidden_batch, 16,
                                             discriminator)
                mean_rewards.append(np.mean(rewards[:, -1]))
            reward = np.mean(mean_rewards)
            buffer = 'epoch:\t' + str(total_batch) + '\treward:\t' + str(
                reward) + '\n'
            print('total_batch: ', total_batch, 'reward: ', reward)
            log.write(buffer)
            generator.save_weight(sess)

        # Update roll-out parameters
        rollout.update_params()
        discriminator.update_embedding()

        # Train the discriminator
        samples = []
        for it in range(gen_data_loader.num_batch):
            inp_batch, image_batch = gen_data_loader.next_batch()
            feed_dict = {
                mobilenet.X: image_batch,
                mobilenet.is_training: False
            }
            hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
            samples.extend(generator.generate(sess, hidden_batch).tolist())
        dis_data_loader.create_batches(random.sample(datas, 3000), samples)
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch):
            x_batch, labels = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.labels: labels,
                discriminator.dropout_keep_prob: 0.75
            }
            _ = sess.run(discriminator.train_op, feed)

    # final test
    gen_data_loader.reset_pointer()
    _, image_batch = gen_data_loader.next_batch()
    feed_dict = {mobilenet.X: image_batch, mobilenet.is_training: False}
    hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
    samples = generator.generate(sess, hidden_batch)
    y = samples.tolist()
    sams = []
    for k, sam in enumerate(y):
        sa = [id_to_word[i] for i in sam]
        sa = ''.join(sa)
        sams.append(sa)
    for sam in sams:
        log.write(sam + '\n')
    log.close()
Example #11
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    # assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size,
                          condition_size,
                          FEATURE_NUM,
                          BATCH_SIZE,
                          EMB_DIM,
                          COND_DIM,
                          HIDDEN_DIM,
                          Z_DIM,
                          SEQ_LENGTH,
                          START_TOKEN,
                          vocab_file,
                          condition_file,
                          word_vec=word_vec)

    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Checkpoint
    saver = tf.train.Saver()
    ckpt = get_ckpt(ckpt_dir)
    if ckpt is not None:
        print("Load checkpoints from: ", ckpt)
        saver.restore(sess, ckpt)

    # Load true data
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')

    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        g_loss, lstm_loss, recon_loss, kl_loss = pre_train_epoch(
            sess, generator, gen_data_loader)
        if epoch % 10 == 0:
            log.write(
                'pre-train epoch %d, g_loss: %f, lstm_loss: %f, recon_loss: %f, kl_loss: %f\n'
                % (epoch, g_loss, lstm_loss, recon_loss, kl_loss))
            print(
                'pre-train epoch %d, g_loss: %f, lstm_loss: %f, recon_loss: %f, kl_loss: %f'
                % (epoch, g_loss, lstm_loss, recon_loss, kl_loss))
            generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                             generated_num, eval_file)

            if epoch % 20 == 0:
                saver.save(sess,
                           os.path.join(ckpt_dir, 'checkpoint_' + str(epoch)))

    print('Start pre-training discriminator...')
    # Train 3 epochs on the generated data and repeat this 50 times
    for _ in range(50):
        generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                         generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess, gen_data_loader.next_batch())
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # # Test
        # if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
        #     generate_samples(sess, generator, gen_data_loader, BATCH_SIZE, generated_num, eval_file)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                             generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()
Example #12
tf.reset_default_graph()

random.seed(SEED)
np.random.seed(SEED)

gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
vocab_size = len(vocab_to_int)  # 6448 (pos)
print(vocab_size)
dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)

generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)

discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
                              filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)
rollout = ROLLOUT(generator, 0.8, word_embedding_matrix)

config = tf.ConfigProto()
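# let TensorFlow allocate GPU memory on demand instead of reserving the whole device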
config.gpu_options.allow_growth = True

sess = tf.Session(config=config)
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())

print('#########################################################################')
print('Restore Trained Seqgan parameters...')
saver.restore(sess, load_model_path)
print("Model restored.")

######################################## TF-IDF #############################################
from tfidf_extract import TFIDF
Ejemplo n.º 13
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    stringGenerator = TextGenerator('../corpus/index2word.pickle',
                                    '../corpus/word2index.pickle',
                                    '../corpus/all.code')

    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = len(stringGenerator.index2Word)
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = get_trainable_model(vocab_size)
    target_params = cPickle.load(open('save/target_params.pkl'))
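    # resize the oracle's input embedding, output projection and bias to the corpus
    # vocabulary; because these entries are re-randomized, the oracle NLL serves only
    # as a rough reference rather than a true likelihood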
    target_params[00] = np.random.rand(vocab_size, 32).astype(np.float32)
    target_params[-2] = np.random.rand(32, vocab_size).astype(np.float32)
    target_params[-1] = np.random.rand(vocab_size).astype(np.float32)
    target_lstm = TARGET_LSTM(vocab_size, 64, 32, 32, 20, 0, target_params)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(sequence_length=20,
                      num_classes=2,
                      vocab_size=vocab_size,
                      embedding_size=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [
        param for param in tf.trainable_variables()
        if 'discriminator' in param.name
    ]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss,
                                                         cnn_params,
                                                         aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars,
                                                 global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.initialize_all_variables())

    #generate_samples(sess, target_lstm, 64, 10000, positive_file)
    stringGenerator.saveSamplesToFile(20, 10000, positive_file)
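    # real training data comes from the corpus text generator instead of an oracle LSTM;
    # presumably 10000 sequences of length 20 are written to positive_file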
    gen_data_loader.create_batches(positive_file)

    log = open('log/experiment-log.txt', 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    significance_test(sess, target_lstm, likelihood_data_loader,
                      'significance/supervise.txt')

    print 'Start training discriminator...'
    for _ in range(dis_alter_epoch):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_file, negative_file)
        dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train),
                                                 dis_batch_size,
                                                 dis_num_epochs)

        for batch in dis_batches:
            try:
                x_batch, y_batch = zip(*batch)
                feed = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, step = sess.run([dis_train_op, dis_global_step], feed)
            except ValueError:
                pass

    rollout = ROLLOUT(generator, 0.8)

    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    for total_batch in range(TOTAL_BATCH):
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, cnn)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss],
                                 feed_dict=feed)

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

            if test_loss < best_score:
                best_score = test_loss
                print 'best score: ', test_loss
                significance_test(sess, target_lstm, likelihood_data_loader,
                                  'significance/seqgan.txt')

        rollout.update_params()

        # generate for discriminator
        print 'Start training discriminator'
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)

            dis_x_train, dis_y_train = dis_data_loader.load_train_data(
                positive_file, negative_file)
            dis_batches = dis_data_loader.batch_iter(
                zip(dis_x_train, dis_y_train), dis_batch_size, 3)

            for batch in dis_batches:
                try:
                    x_batch, y_batch = zip(*batch)
                    feed = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, step = sess.run([dis_train_op, dis_global_step], feed)
                except ValueError:
                    pass

    log.close()
Ejemplo n.º 14
0
def main():

    starttime = datetime.datetime.now()
    short = {}
    # graph = read_graph_edgelist(network_file)
    graph = read_graph_adjlist(network_file)
    graph_nx = nx.from_dict_of_lists(graph)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)
    generated_num = np.loadtxt(positive_file).shape[0]
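    # generate as many fake walks as there are real random-walk paths so the
    # discriminator trains on a balanced positive/negative split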
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    saver = tf.train.Saver(max_to_keep=12)
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, generate true random-walk paths
    generate_samples_true(sess, graph, BATCH_SIZE, generated_num,
                          positive_file, SEQ_LENGTH)
    gen_data_loader.create_batches(positive_file)
    #  pre-train generator
    print('Start pre-training...')
    for epoch in range(15):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        endtime = datetime.datetime.now()
        print('pre-train epoch:', epoch, ' test_loss ', loss, ' time:',
              (endtime - starttime).seconds)

    #  pre-train discriminator
    print('Start pre-training discriminator...')
    for _ in range(3):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            d_loss_his = []
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, d_loss = sess.run(
                    [discriminator.train_op, discriminator.loss], feed)
                d_loss_his.append(d_loss)
        endtime = datetime.datetime.now()
        loss, length = dis_ypred_for_auc(sess, discriminator, generator, short,
                                         graph_nx)
        print('discriminator loss: ', np.mean(d_loss_his), 'ypred:', loss,
              'length', length, 'time: ', (endtime - starttime).seconds)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    for total_batch in range(TOTAL_BATCH):
        if total_batch % 5 == 0:
            saver.save(sess, 'log/train.checkpoint', global_step=total_batch)
        # Train the generator for one step
        for it in range(20):
            generator.graph = graph
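            # start each batch of walks from nodes drawn uniformly at random; attaching
            # the graph presumably lets the generator restrict sampling to real edges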
            start_token = np.random.randint(vocab_size, size=[BATCH_SIZE])
            samples = generator.generate(sess, start_token)
            rewards, sample_temp = rollout.get_reward(sess, samples, 16,
                                                      discriminator)
            feed = {
                generator.x: samples,
                generator.rewards: rewards,
                generator.start_token: start_token
            }
            _, gen_loss = sess.run([generator.g_updates, generator.g_loss],
                                   feed_dict=feed)
            generator.graph = None
            loss, length = dis_ypred_for_auc(sess, discriminator, generator,
                                             short, graph_nx)
            endtime = datetime.datetime.now()

            print('before total_batch: ', total_batch, 'reward: ', loss,
                  'length:', length, 'test_loss: ', gen_loss, 'time: ',
                  (endtime - starttime).seconds)
        rollout.update_params()
        loss, length = dis_ypred_for_auc(sess, discriminator, generator, short,
                                         graph_nx)
        print('after total_batch: ', total_batch, 'reward: ', loss, 'length',
              length, 'time: ', (endtime - starttime).seconds)
        # Train the discriminator
        for _ in range(3):
            generator.graph = None
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            d_loss_his = []
            for _ in range(1):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, d_loss = sess.run(
                        [discriminator.train_op, discriminator.loss], feed)
                    d_loss_his.append(d_loss)
            endtime = datetime.datetime.now()
            loss, length = dis_ypred_for_auc(sess, discriminator, generator,
                                             short, graph_nx)
            print('discriminator loss: ', np.mean(d_loss_his), 'ypred:', loss,
                  'length', length, 'time: ', (endtime - starttime).seconds)

        if total_batch % 5 == 0:
            saver.save(sess, 'log/train.checkpoint', global_step=total_batch)
    saver.save(sess, 'log/train.checkpoint', global_step=TOTAL_BATCH)
    # generate the final fake paths
    generate_samples(sess, generator, BATCH_SIZE, generated_num,
                     'save/final.txt')
    command = 'deepwalk --input example_graphs/karate.adjlist --output karate.embeddings --extra save/final.txt'
    subprocess.call(command, shell=True)
Ejemplo n.º 15
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    vocab_size = 5000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    target_params = cPickle.load(open('save/target_params.pkl'))
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              SEQ_LENGTH, START_TOKEN,
                              target_params)  # The oracle model

    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    # generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    print 'Start pre-training discriminator...'
    # Train for 3 epochs on the generated data and repeat this 50 times
    for epoch in range(50):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        if epoch % 5 == 0:
            print 'pre-train discriminator epoch ', epoch
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print '#########################################################################'
    print 'Start Adversarial Training...'
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()
Ejemplo n.º 16
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=num_classes, vocab_size=vocab_size, embedding_size=dis_embedding_dim, 
                                filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)
    #target_params = cPickle.load(open('save/target_params.pkl'))
    #target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    # avoid occupying all of the GPU memory
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    #Savers 
    saver_gen = tf.train.Saver()
    saver_dis = tf.train.Saver()
    saver_seqgan = tf.train.Saver()

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    gen_data_loader.create_batches(positive_file)  # load the data into batches

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training Generator...') #MLE
    log.write('pre-training generator...\n') 
    for epoch in range(PRE_GEN_EPOCH_NUM):
        s = time.time()
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        
        # keep the best (lowest-loss) pre-trained generator seen so far
        if epoch == 0 or loss < best:
            best = loss
            saver_gen.save(sess, "model/pretrain_gen_best")

        if epoch % 5 == 0:
            print('pre-train epoch: ', epoch, 'loss: ', loss, "time: ", time.time()-s)
            log.write('epoch:\t'+ str(epoch) + '\tloss:\t' + str(loss) + '\n')

    # pre-train discriminator
    print('Start pre-training discriminator...')
    log.write('pre-training discriminator...\n') 
    
    # Train for 3 epochs on the generated data and repeat this PRE_DIS_EPOCH_NUM times
    for epoch in range(PRE_DIS_EPOCH_NUM):
        s = time.time()
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _,acc = sess.run([discriminator.train_op,discriminator.accuracy], feed)

        # keep the discriminator with the best accuracy seen so far
        if epoch == 0 or acc > best:
            best = acc
            saver_dis.save(sess, "./model/pretrain_dis_best")

        print("pre-train epoch: ", epoch, " acc: ", acc," time: ", time.time()-s)
        log.write("epoch:\t" + str(epoch) + "\tacc:\t" + str(acc) + "\n")

    rollout = ROLLOUT(generator, 0.8)

    print( '#########################################################################')
    print( 'Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        s = time.time()

        for it in range(ADV_GEN_TIME):
            samples = generator.generate(sess) # sample a sequence
            rewards = rollout.get_reward(sess, samples, 16, discriminator) #MC search
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed) # do policy gradient

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1: # log average reward
            avg = np.mean(np.sum(rewards, axis=1), axis=0) / SEQ_LENGTH
         
            #print('total_batch: ', total_batch, 'average reward: ', avg)
            log.write('epoch:\t' + str(total_batch) + '\treward:\t' + str(avg) + '\n')

            saver_seqgan.save(sess, "./model/seq_gan", global_step=total_batch)

        # Update roll-out parameters
        rollout.update_params() # sync the rollout network with the updated generator

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

        print('epoch: ', total_batch, 'average reward: ', avg," time: ",time.time()-s)
    
    log.close() 

    # generate examples
    print("Training Finished, starting to generating test")
    generate_samples(sess, generator, BATCH_SIZE, test_num,generate_file)
    
    print("Finish")
Ejemplo n.º 17
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    with open('data/ihaiku.pickle', 'rb') as f:
        haiku_list = pickle.load(f)
    #usew2v---------------------------------------------------------------------------------------------
    with open('data/index.pickle', 'rb') as f:
        index = pickle.load(f)
    vocab_size = len(index)
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)

    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        select_haikus(haiku_list, generated_num, positive_file)
        gen_data_loader.create_batches(positive_file)
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            print('pre-train epoch ', epoch)
            buffer = 'epoch:\t' + str(epoch) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    # Train for 3 epochs on the generated data and repeat this 50 times
    for _ in range(50):
        select_haikus(haiku_list, generated_num, positive_file)
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            kigos = select_kigos(kigo_list, BATCH_SIZE)
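            # generation is conditioned on the sampled kigo (season words); the returned
            # `rate` is presumably a per-sample weight that get_reward uses below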
            samples, rate = generator.generate_with_rate(sess, kigos)
            rewards = rollout.get_reward(sess, samples, 16, discriminator,
                                         rate)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            select_haikus(haiku_list, generated_num, positive_file)
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

        # Test
        print(
            'total_batch:',
            total_batch,
        )
        if (total_batch - 1) % 50 == 0:
            output_file = 'result/result_{0:04d}_epoch.txt'.format(total_batch)
            generate_samples_with_pred(sess, generator, discriminator,
                                       BATCH_SIZE, generated_num, output_file)
            buffer = 'epoch:\t' + str(total_batch) + '\n'
            log.write(buffer)

    log.close()
Ejemplo n.º 18
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    with open(true_file, 'r') as f_pos:
        file_contents = f_pos.read().splitlines()
        file_contents = [content.split() for content in file_contents]
        tokens = set([item for sublist in file_contents for item in sublist])
        # tokens = set(file_contents)

    pad_idx = len(tokens)
    vocab_size = pad_idx + 1

    token2idx = dict((token, i) for i, token in enumerate(tokens))
    idx2token = dict((i, token) for i, token in enumerate(tokens))
    idx2token[pad_idx] = " "
    load_positive(true_file, positive_file, token2idx, pad_idx)

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing

    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    target_params = cPickle.load(open('save/target_params.pkl', 'rb'),
                                 encoding='latin1')
    target_params[0] = np.random.random([vocab_size, 32]).astype(np.float32)
    target_params[13] = np.random.random([32, vocab_size]).astype(np.float32)
    target_params[14] = np.random.random([
        vocab_size,
    ]).astype(np.float32)

    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              SEQ_LENGTH, START_TOKEN,
                              target_params)  # The oracle model

    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    gen_data_loader.create_batches(positive_file, SEQ_LENGTH)

    # log file that stores progress
    log = open('save/experiment-log.txt', 'w')

    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')

    all_pre_train_losses = []
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        all_pre_train_losses.append(loss)

    plt.plot(all_pre_train_losses)
    plt.savefig('pre_train_losses_plot.png')

    gen_outfile = 'save/generated_by_generator_after_' + str(
        PRE_EPOCH_NUM) + '_' + str(datetime.datetime.now()) + '_epochs.txt'

    generate_samples(sess, generator, BATCH_SIZE, generated_num, gen_outfile,
                     idx2token)

    checksyntax.check_code(log, gen_outfile)
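    # the decoded samples are syntax-checked as code; this replaces the oracle NLL as
    # the quality measure here (see the commented-out evaluation block below)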

    # if epoch % 5 == 0:
    #     generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    #     likelihood_data_loader.create_batches(eval_file, SEQ_LENGTH)
    #     test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    #     print('pre-train epoch ', epoch, 'test_loss ', test_loss)
    #     buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
    #     log.write(buffer)

    print('Start pre-training discriminator...')
    # Train for 3 epochs on the generated data and repeat this 50 times
    for i in range(50):
        print("discriminator pre train epoch : ", i)
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file,
                                        SEQ_LENGTH)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):

                x_batch, y_batch = dis_data_loader.next_batch()

                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    gen_outfile = 'save/generated_by_generator_after_discriminator_training_' + str(
        datetime.datetime.now()) + '.txt'

    generate_samples(sess, generator, BATCH_SIZE, generated_num, gen_outfile,
                     idx2token)

    checksyntax.check_code(log, gen_outfile)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        print("total_batch : ", total_batch)
        if total_batch % 20 == 0:
            file_name = 'save/output_batch_' + str(total_batch) + '.txt'
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             file_name, idx2token)

            checksyntax.check_code(log, file_name)

        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
        # Test
    #     if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
    #         generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    #         likelihood_data_loader.create_batches(eval_file)
    #         test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    #         buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
    #         print('total_batch: ', total_batch, 'test_loss: ', test_loss)
    #         log.write(buffer)
    #
    # Update roll-out parameters
        rollout.update_params()
        #
        # Train the discriminator
        for _ in range(1):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file,
                                            SEQ_LENGTH)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
        final_gen_file = 'save/final_output.txt'
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         final_gen_file, idx2token)

        checksyntax.check_code(log, final_gen_file)

    #     with open('save/output.txt','r') as f:
    #         with open('save/output_word.txt','w') as fout:
    #             for line in f:
    #                 line = line.strip()
    #                 line = line.split()
    #                 word_line = ''.join([idx2token[int(x)] for x in line])
    #                 fout.write(word_line + '\n')
    #
    log.close()
Ejemplo n.º 19
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = 5000

    best_score = 9.1
    generator = get_trainable_model(vocab_size)
    # target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # generate_samples(sess, target_lstm, 64, 10000, positive_file)
    ################################################################
    positive_data = np.load(positive_file).tolist()
    gen_data_loader.create_batches(positive_data)
    references = load_references(positive_data)

    log = open('log/pg_experiment-log.txt', 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            #     generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            #     likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            # print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            # buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            # log.write(buffer)
            print 'pre-train epoch ', epoch, 'loss ', loss
            buffer = str(epoch) + ' ' + str(loss) + '\n'
            log.write(buffer)

    # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    # likelihood_data_loader.create_batches(eval_file)
    # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    # buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    print 'After pre-train epoch ', loss
    buffer = str(loss) + '\n'
    log.write(buffer)

    # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    # likelihood_data_loader.create_batches(eval_file)
    # significance_test(sess, target_lstm, likelihood_data_loader, 'significance/supervise.txt')

    rollout = ROLLOUT(generator, references)

    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    for total_batch in range(TOTAL_BATCH):
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            print 'start calculating BLEU...'
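            # the reward here is presumably sentence BLEU against the reference set rather
            # than a discriminator score: a single rollout with uniform 1/3 weights over
            # 1- to 3-gram precision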
            rewards = rollout.get_reward(sess, samples, 1, (1.0 / 3, 1.0 / 3, 1.0 / 3))
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss], feed_dict=feed)

            # if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            #     generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            #     likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            # buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            # print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            # log.write(buffer)

            # if test_loss < best_score:
            #     best_score = test_loss
            #     print 'best score: ', test_loss
            #     significance_test(sess, target_lstm, likelihood_data_loader, 'significance/pg_bleu.txt')
            print('Current loss:' + str(total_batch) + ':' + str(g_loss))
        rollout.update_params()

    log.close()

    generate_samples(sess, generator, BATCH_SIZE, 100, final_trans_file)
Ejemplo n.º 20
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
    vocab_size = 5000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    # with open('./save/target_params.pkl','rb') as f:
    #     print('readable:',f.readable())
    #     target_params=pickle.load(f)
    target_params = pickle.load(open('save/target_params.pkl','rb'),encoding='ISO-8859-1')
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    discriminator = Discriminator(sequence_length=20, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim, 
                                filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print ('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print ('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)

    print ('Start pre-training discriminator...')
    # Train for 3 epochs on the generated data and repeat this 50 times
    for _ in range(50):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print ('#########################################################################')
    print ('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()
    print(test_loss)
Ejemplo n.º 21
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = 5000
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = get_trainable_model(vocab_size)

    # oracle model : target lstm
    # target_params = cPickle.load(open('save/target_params.pkl'))
    # target_lstm = TARGET_LSTM(vocab_size, 64, 32, 32, SEQ_LENGTH, 0, target_params)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(
            sequence_length=SEQ_LENGTH,
            num_classes=2,
            vocab_size=vocab_size,
            embedding_size=dis_embedding_dim,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [param for param in tf.trainable_variables() if 'discriminator' in param.name]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss, cnn_params, aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars, global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # czq
    # generate real data
    # generate_samples(sess, target_lstm, 64, 10000, positive_file)

    # store real data for next step
    positive_data = np.load(positive_file).tolist()
    gen_data_loader.create_batches(positive_data)

    log = open('log/seq_mle_experiment-log.txt', 'w')
    #  pre-train generator
    print '#########################################################################'
    print 'Start pre-training generator...'
    log.write('pre-training...\n')

    for epoch in xrange(PRE_EPOCH_NUM):
        # print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            # likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            # print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            # buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            buffer = 'pre-trained generator:' + str(epoch) + ' ' + str(loss)
            print(buffer)
            log.write(buffer + '\n')

    # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    # likelihood_data_loader.create_batches(eval_file)
    # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    # buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    buffer = 'After pre-training:' + ' ' + str(loss)
    print(buffer)
    log.write(buffer + '\n')

    # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    # likelihood_data_loader.create_batches(eval_file)
    # significance_test(sess, target_lstm, likelihood_data_loader, 'significance/supervise.txt')

    # for testing purposes only
    generate_samples(sess, generator, BATCH_SIZE, 100, final_trans_file_mle)

    # exit(0)

    print 'Start pre-training discriminator...'
    for _ in range(dis_alter_epoch):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(positive_file, negative_file)
        dis_batches = dis_data_loader.batch_iter(
            zip(dis_x_train, dis_y_train), dis_batch_size, dis_num_epochs
        )

        for batch in dis_batches:
            try:
                x_batch, y_batch = zip(*batch)
                feed = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, step = sess.run([dis_train_op, dis_global_step], feed)
            except Exception as e:
                # print str(e)
                raise

        loss = sess.run(cnn.loss, feed)
        buffer = 'pre-train discriminator' + ' ' + str(loss)
        print buffer
        log.write(buffer + '\n')

    rollout = ROLLOUT(generator, 0.8)
    print('Before GAN')
    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    # for tensorboard
    # writer = tf.summary.FileWriter('./tb_logs', graph=tf.get_default_graph())

    for total_batch in range(TOTAL_BATCH):
        print 'progress', total_batch, '/', TOTAL_BATCH
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, cnn)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss, pre_loss = sess.run([generator.g_updates, generator.g_loss, generator.pretrain_loss],
                                           feed_dict=feed)
            buffer = 'G-step:' + str(TRAIN_ITER) + ':' + str(g_loss) + '|' + str(pre_loss)
            log.write(buffer + '\n')
            print(buffer)
            # if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            #     generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            #     likelihood_data_loader.create_batches(eval_file)
            #     # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            #     # buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            #     # print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            #     log.write(buffer)

            # if test_loss < best_score:
            #     best_score = test_loss
            #     print 'best score: ', test_loss
            #     significance_test(sess, target_lstm, likelihood_data_loader, 'significance/seqgan.txt')

        rollout.update_params()

        # generate for discriminator
        print('Start training discriminator')
        log.write('training discriminator...\n')

        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

            dis_x_train, dis_y_train = dis_data_loader.load_train_data(positive_file, negative_file)
            dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train), dis_batch_size, 3)

            for batch in dis_batches:
                try:
                    x_batch, y_batch = zip(*batch)
                    feed = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, step = sess.run([dis_train_op, dis_global_step], feed)
                except ValueError:
                    pass

            loss = sess.run(cnn.loss, feed)
            buffer = 'discriminator' + ' ' + str(loss)
            print buffer
            log.write(buffer + '\n')

    log.close()

    # save the model
    # saver = tf.train.Saver({"gen": generator})
    # saver.save(sess, 'my-model')

    # generate samples
    generate_samples(sess, generator, BATCH_SIZE, 100, final_trans_file_seqgan)
Ejemplo n.º 22
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    # data loaders declaration
    # loaders for generator, discriminator, and additional validation data loader
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    eval_data_loader = Gen_Data_loader(BATCH_SIZE)

    # define generator and discriminator
    # the general structure is the same as the original model
    # the generator learning rate needs heavy tuning for general use
    # L2 regularization for D & G also affects performance
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN, GENERATOR_LR, REWARD_GAMMA)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    # limit VRAM usage (allow on-demand growth) for efficient deployment
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)
    sess.run(tf.global_variables_initializer())
    # define saver
    saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=1)
    # generate real data from the true dataset
    gen_data_loader.create_batches(positive_file)
    # generate real validation data from true validation dataset
    eval_data_loader.create_batches(valid_file)

    time = str(datetime.datetime.now())[:-7]
    log = open('save/experiment-log-conditional' + str(time) + '.txt', 'w')
    log.write(str(config) + '\n')
    log.write('D loss: original\n')
    log.flush()

    #summary_writer = tf.summary.FileWriter('save/tensorboard/', graph=tf.get_default_graph())

    if config['pretrain'] == True:
        #  pre-train generator
        print 'Start pre-training...'
        log.write('pre-training...\n')
        for epoch in xrange(PRE_GEN_EPOCH):
            # calculate the loss by running an epoch
            loss = pre_train_epoch_condtional(sess, generator, gen_data_loader)

            # measure bleu score with the validation set
            bleu_score = calculate_bleu(sess, generator, eval_data_loader)

            # since the real data already comes from the true data distribution, only the
            # pre-training loss is evaluated; note the absence of the oracle model, which
            # is meaningless for general use
            buffer = 'pre-train epoch: ' + str(
                epoch) + ' pretrain_loss: ' + str(loss) + ' bleu: ' + str(
                    bleu_score)
            print(buffer)
            log.write(buffer + '\n')
            log.flush()

            # generate 5 test samples per epoch: this samples from the generator,
            # post-processes the output into MIDI files, and saves them to the
            # pre-defined folder
            if epoch == 0:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                POST.main(negative_file, 5,
                          str(-1) + '_vanilla_', 'midi_conditional')
            elif epoch == PRE_GEN_EPOCH - 1:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                POST.main(negative_file, 5,
                          str(-PRE_GEN_EPOCH) + '_vanilla_',
                          'midi_conditional')

        print 'Start pre-training discriminator...'
        # Train for 3 epochs on the generated data and repeat this PRE_DIS_EPOCH times
        # this schedule is in the spirit of the original work, but the epoch strategy needs tuning
        for epochs in range(PRE_DIS_EPOCH):
            generate_samples_conditional_v2(sess, gen_data_loader, generator,
                                            BATCH_SIZE, generated_num,
                                            negative_file)
            D_loss = 0
            for _ in range(3):
                dis_data_loader.load_train_data(positive_file, negative_file)
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
                    D_loss += discriminator.loss.eval(feed, session=sess)
            buffer = 'epoch: ' + str(epochs + 1) + '  D loss: ' + str(
                D_loss / dis_data_loader.num_batch / 3)
            print(buffer)
            log.write(buffer + '\n')
            log.flush()

        # save the pre-trained checkpoint for future use
        # to run adversarial training only on later runs, comment out the pre-training section once this checkpoint has been saved
        save_checkpoint(sess, saver, PRE_GEN_EPOCH, PRE_DIS_EPOCH)

    # define the rollout object
    # the second parameter specifies the target update rate
    # a higher rate makes the rollout more "conservative", updating it less toward the learned generator
    # we found that a higher update rate stabilized learning by constraining divergence of the generator
    rollout = ROLLOUT(generator, ROLLOUT_UPDATE_RATE)
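    # A minimal sketch of what update_params typically does in SeqGAN-style
    # rollout modules (an assumption; the exact rule lives inside the ROLLOUT class):
    # each rollout weight is kept as an exponential moving average,
    #     rollout_param = rate * rollout_param + (1 - rate) * generator_param
    # so with a high ROLLOUT_UPDATE_RATE the rollout network tracks the generator
    # slowly, which is the "conservative" behaviour described above.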

    print '#########################################################################'
    print 'Start Adversarial Training...'
    log.write('adversarial training...\n')
    if config['pretrain'] == False:
        # load checkpoint of pre-trained model
        load_checkpoint(sess, saver)

    # scale the generator learning rate up from 0.001 to 0.01
    if config['x10adv_g'] == True:
        generator.learning_rate *= 10

    for total_batch in range(TOTAL_BATCH):
        G_loss = 0
        # Train the generator for epochs_generator steps
        for it in range(epochs_generator):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, config['rollout_num'],
                                         discriminator)
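            # In the standard SeqGAN scheme (assumed here), rewards has shape
            # [BATCH_SIZE, SEQ_LENGTH]: for each prefix length the rollout network
            # completes the sequence config['rollout_num'] times and the
            # discriminator's probability of "real" is averaged over those
            # completions, giving a per-timestep credit signal for policy gradient.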
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
            G_loss += generator.g_loss.eval(feed, session=sess)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        D_loss = 0
        for _ in range(epochs_discriminator):
            generate_samples_conditional_v2(sess, gen_data_loader, generator,
                                            BATCH_SIZE, generated_num,
                                            negative_file)
            for _ in range(config['epochs_discriminator_multiplier']):
                dis_data_loader.load_train_data(positive_file, negative_file)
                dis_data_loader.reset_pointer()

                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
                    D_loss += discriminator.loss.eval(feed, session=sess)

        # evaluate stability and performance with the BLEU score
        bleu_score = calculate_bleu(sess, generator, eval_data_loader)
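        # calculate_bleu is assumed to sample sequences from the generator and
        # score them against the validation batches, e.g. with a corpus-level
        # BLEU such as nltk's corpus_bleu; a falling score here is the signal
        # used below to detect mode collapse.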
        buffer = 'epoch: ' + str(total_batch + 1) + \
                 ',  G_adv_loss: %.12f' % (G_loss / epochs_generator) + \
                 ',  D loss: %.12f' % (D_loss / epochs_discriminator / config['epochs_discriminator_multiplier']) + \
                 ',  bleu score: %.12f' % bleu_score
        print(buffer)
        log.write(buffer + '\n')
        log.flush()

        if config['infinite_loop'] is True:
            if bleu_score < config['loop_threshold']:
                buffer = 'Mode collapse detected, restarting from pretrained model...'
                print(buffer)
                log.write(buffer + '\n')
                log.flush()
                load_checkpoint(sess, saver)

        # generate random test samples and postprocess the sequences into MIDI files
        #generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

        # instead of the above, generate samples conditionally
        # randomly sample a batch
        # rng = np.random.randint(0, high=gen_data_loader.num_batch, size=1)
        # random_batch = np.squeeze(gen_data_loader.sequence_batch[rng])
        generate_samples_conditional_v2(sess, gen_data_loader, generator,
                                        BATCH_SIZE, generated_num,
                                        negative_file)
        POST.main(negative_file, 5,
                  str(total_batch) + '_vanilla_', 'midi_conditional')
    log.close()
Ejemplo n.º 23
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    vocab_dict, vocab_res = data_utils.load_vocab('./vocab.txt')
    data = data_utils.load_data('data.pkl')
    keywords = data_utils.load_data('kwd.pkl')

    print(len(keywords))
    print(len(data))
    # data = data[:1000]
    tn_size = int(len(data) * 0.8)
    tn_loader = DataLoader(data[:tn_size], keywords[:tn_size], BATCH_SIZE)
    ts_loader = DataLoader(data[tn_size:], keywords[tn_size:], BATCH_SIZE)
    print('number of data samples: ', len(data))
    vocab_size = len(vocab_dict)
    SEQ_LENGTH = data.shape[1]
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, 1604, START_TOKEN)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables())
    last_epoch = load_model(sess, saver, model_dir)

    if last_epoch <= 0:
        # pre-train generator
        print('Start pre-training...')
        for epoch in range(PRE_EPOCH_NUM):
            loss = pre_train_epoch(sess, generator, tn_loader)
            if epoch % 5 == 0:
                test_loss = target_loss(sess, generator, ts_loader)
                print('pre-train epoch ', epoch, 'train loss', loss,
                      'test_loss ', test_loss)

        print('Start pre-training discriminator...')
        # Train 3 epoch on the generated data and do this for 50 times
        for epoch_i in range(DIS_PRE_EPOCH_NUM):
            # --------- changed by zhoujifa -------------- #
            gen_data = generate_samples(sess, generator, BATCH_SIZE,
                                        generated_num, keywords)

            print(gen_data.shape)
            print(np.shape(data[:tn_size]))

            dis_loader = DisDataLoader(data[:tn_size], gen_data, BATCH_SIZE)

            for _ in range(3):
                losses = []
                for index, (x_batch,
                            y_batch) in enumerate(dis_loader.next_batch()):
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, dis_loss = sess.run(
                        [discriminator.train_op, discriminator.loss], feed)
                    losses.append(dis_loss)
                    if index % 1000 == 0:
                        print('\tepoch: {}, batch index : {}, loss: {}'.format(
                            epoch_i, index, dis_loss))
                print('epoch: {}, loss: {}'.format(epoch_i, np.mean(losses)))
    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    for total_batch in range(last_epoch + 1, TOTAL_BATCH):
        # Train the generator for 10 steps per adversarial batch
        for it in range(10):
            """
            changed by zhoujifa
            """
            kwd = select_keywords(keywords)
            samples = generator.generate(sess, kwd)
            rewards = rollout.get_reward(sess, samples, kwd, 16, discriminator)
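            # Keyword-conditioned variant: the samples and the 16 Monte-Carlo
            # rollouts are seeded with the same keyword batch kwd, so the
            # per-timestep rewards remain consistent with the conditioning
            # (inferred from the call signature above).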
            feed = {
                generator.x: samples,
                generator.rewards: rewards,
                generator.keywords: kwd
            }
            _ = sess.run(generator.g_updates, feed_dict=feed)
            """
            end
            """

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            test_loss = target_loss(sess, generator, ts_loader)
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
        # Update roll-out parameters

        rollout.update_params()

        # Train the discriminator
        for _ in range(1):
            # ------ changed by zhoujifa ---------- #
            gen_data = generate_samples(sess, generator, BATCH_SIZE,
                                        generated_num, keywords)
            dis_loader = DisDataLoader(data[:tn_size], gen_data, BATCH_SIZE)

            for epoch in range(1):
                losses = []
                for index, (x_batch,
                            y_batch) in enumerate(dis_loader.next_batch()):
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, dis_loss = sess.run(
                        [discriminator.train_op, discriminator.loss], feed)
                    losses.append(dis_loss)
                    if index % 1000 == 0:
                        print('\tepoch: {}, batch index : {}, loss: {}'.format(
                            epoch, index, dis_loss))
                print('\tepoch: {}, loss: {}'.format(epoch, np.mean(losses)))
        saver.save(sess, model_dir + 'poetry.module', global_step=total_batch)
    for i in range(5):
        if i >= len(samples):
            break
        arr = samples[i]
        poem = ''
        for index in arr:
            if index != data_utils.EOS_ID:
                poem += vocab_res[index]
        print(poem)
    sess.close()
Ejemplo n.º 24
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = 5000

    best_score = 9.1
    generator = get_trainable_model(vocab_size)
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    generate_samples(sess, target_lstm, 64, 10000, positive_file)
    ################################################################
    gen_data_loader.create_batches(positive_file)
    references = load_references(positive_file)

    log = open('log/experiment-log.txt', 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    significance_test(sess, target_lstm, likelihood_data_loader, 'significance/supervise.txt')

    rollout = ROLLOUT(generator, references)

    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    for total_batch in range(TOTAL_BATCH):
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            print 'start calculating BLEU...'
            rewards = rollout.get_reward(sess, samples, 1, (1.0 / 3, 1.0 / 3, 1.0 / 3))
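            # PG-BLEU style reward (assumed from the arguments and the
            # 'significance/pg_bleu.txt' output below): instead of a discriminator,
            # each sample is scored by BLEU against the loaded references, with
            # uniform 1/3 weights over 1-, 2- and 3-grams and a single rollout
            # per prefix.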
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss], feed_dict=feed)

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

            if test_loss < best_score:
                best_score = test_loss
                print 'best score: ', test_loss
                significance_test(sess, target_lstm, likelihood_data_loader, 'significance/pg_bleu.txt')

        rollout.update_params()

    log.close()
Ejemplo n.º 25
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    # TODO: I changed this.  Why was this asserted?  Was it just to ensure the replication
    # of results?  Or is zero important otherwise?
    # Changed because 0 is a bad start token for our data.  (cannot have home label=0)
    # assert START_TOKEN == 0

    # set up logging
    log_fpath = logger.get_experiment_log_filepath()

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH) # For testing
    vocab_size = VOCAB_SIZE
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size,
                                    embedding_size=dis_embedding_dim, filter_sizes=dis_filter_sizes,
                                    num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    if not USE_GPU:
        # Prevent the environment from seeing the available GPUs (to avoid error on matlaber cluster)
        import os
        os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    gen_data_loader.create_batches(real_file)

    #  pre-train generator
    logger.write_log(log_fpath, 'pre-training generator...')
    for epoch in xrange(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            logger.write_log(log_fpath, 'generator loss:')
            logger.log_progress(log_fpath, epoch, loss)
            generate_samples(sess, generator, BATCH_SIZE, eval_generated_num, eval_file.format('pretrain'))

    logger.write_log(log_fpath, 'Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for i in range(50):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, fake_file)
        dis_data_loader.load_train_data(real_file, fake_file)
        # dis_data_loader.load_train_data(positive_file, negative_file)
        logger.write_log(log_fpath, 'epoch iterator:  %s / 50' % i)
        for j in range(3):
            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _d_train_output = sess.run(discriminator.train_op, feed)

    logger.write_log(log_fpath, 'finished pre-training discriminator')
    rollout = ROLLOUT(generator, 0.8)

    logger.write_log(log_fpath, 'Start Adversarial Training...')
    g_steps = 1
    d_steps = 1
    k = 10
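    # Per adversarial batch below: g_steps policy-gradient updates of the
    # generator, then d_steps rounds in which fresh negative samples are
    # generated and the discriminator is trained for k epochs on them
    # (the g-steps / d-steps / k schedule of the original SeqGAN training loop).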
    for batch in range(TOTAL_BATCH):
        buff = 'batch %s/%s' % (batch, TOTAL_BATCH)
        logger.write_log(log_fpath, buff)
        # Train the generator for one step
        for it in range(g_steps):
            samples = generator.generate(sess)
            rollout_num = 16  # TODO: experiment with this value
            rewards = rollout.get_reward(sess, samples, rollout_num, discriminator)
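            # Trade-off for the TODO above: a larger rollout_num averages the
            # discriminator score over more Monte-Carlo completions, lowering the
            # variance of the reward estimate, but the cost of get_reward grows
            # roughly linearly with it.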
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if batch % 5 == 0 or batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, eval_generated_num, eval_file.format(batch))
            logger.write_log(log_fpath, 'generated some more eval samples...')

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(d_steps):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, fake_file)
            dis_data_loader.load_train_data(real_file, fake_file)

            for _ in range(k):
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    logger.write_log(log_fpath, 'I\'M DONE')
Ejemplo n.º 26
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Generator_Data_Loader(BATCH_SIZE)
    # For testing
    likelihood_data_loader = Generator_Data_Loader(BATCH_SIZE)
    vocab_size = 19851
    dis_data_loader = Discriminator_Data_Loader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    # target_params = pickle.load(open('data/target_params_py3.pkl', 'rb'))
    # The oracle model - synthetic data
    # target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params)

    discriminator = Discriminator(seq_len=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  emb_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    # generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('data/experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training...')
    log.write('Pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        start = time.time()
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        print("Epoch ", epoch, " Loss: ", loss)
        print("Per epoch time consumed: ", time.time() - start)

        # if epoch % 5 == 0:
        #     generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
        #     likelihood_data_loader.create_batches(eval_file)
        #     test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
        #     print(f'Pre-train epoch: {epoch}, Test_loss: {test_loss}')
        #     buffer = "Epoch:\t"+ str(epoch) + "\tNeg-Log Likelihood:\t" + str(test_loss) + "\n"
        #     log.write(buffer)

    print('Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for _ in tqdm(range(50)):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('Adversarial training...\n')
    for total_batch in tqdm(range(TOTAL_BATCH)):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            # likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = "Epoch:\t" + str(total_batch) + "\tReward:\t" + str(
                rewards) + "\n"
            print(f'Total Batch: {total_batch}, Reward {rewards}')
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    # Final generation
    print("Writing final results to test file")
    test_file = "data/final.txt"
    generate_samples(sess, generator, BATCH_SIZE, generated_num, test_file)
    print("Finished")

    log.close()
Ejemplo n.º 27
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = 68
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    # load generator with parameters
    generator = get_trainable_model(vocab_size)
    target_params = initialize_parameters(vocab_size)

    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              SEQ_LENGTH, START_TOKEN, target_params)

    # CNNs
    with tf.variable_scope('discriminator'):
        cnn = TextCNN(sequence_length=SEQ_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      embedding_size=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [
        param for param in tf.trainable_variables()
        if 'discriminator' in param.name
    ]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss,
                                                         cnn_params,
                                                         aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars,
                                                 global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # generate_samples(sess, target_lstm, 64, 10000, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open(logpath, 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            file_name = 'target_generate/pretrain_epoch' + str(epoch) + '.pkl'
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             file_name)
            likelihood_data_loader.create_batches(file_name)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            log.write(buffer)

            if epoch % 100 != 0:
                os.remove(file_name)

    file_name = 'target_generate/pretrain_finished.pkl'
    generate_samples(sess, generator, BATCH_SIZE, generated_num, file_name)
    likelihood_data_loader.create_batches(file_name)
    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    log.write(buffer)

    file_name = 'target_generate/supervise.pkl'
    generate_samples(sess, generator, BATCH_SIZE, generated_num, file_name)
    likelihood_data_loader.create_batches(file_name)
    significance_test(sess, target_lstm, likelihood_data_loader,
                      'significance/supervise.txt')

    os.remove(file_name)

    print 'Start training discriminator...'
    for i in range(dis_alter_epoch):
        print 'dis_alter_epoch : ' + str(i)
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_file, negative_file)
        dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train),
                                                 dis_batch_size,
                                                 dis_num_epochs)

        for batch in dis_batches:
            try:
                x_batch, y_batch = zip(*batch)
                feed = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, step = sess.run([dis_train_op, dis_global_step], feed)
            except ValueError:
                pass

    rollout = ROLLOUT(generator, 0.8)

    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    for total_batch in range(TOTAL_BATCH):
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, cnn)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss],
                                 feed_dict=feed)

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:

            file_name = 'target_generate/reinforce_batch' + str(
                total_batch) + '.pkl'

            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             file_name)
            likelihood_data_loader.create_batches(file_name)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

            if total_batch % 50 != 0:
                os.remove(file_name)

            if test_loss < best_score:
                best_score = test_loss
                print 'best score: ', test_loss
                significance_test(sess, target_lstm, likelihood_data_loader,
                                  'significance/seqgan.txt')

        rollout.update_params()

        # generate for discriminator
        print 'Start training discriminator'
        for _ in range(5):
            # for _ in range(2):

            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)

            dis_x_train, dis_y_train = dis_data_loader.load_train_data(
                positive_file, negative_file)
            dis_batches = dis_data_loader.batch_iter(
                zip(dis_x_train, dis_y_train), dis_batch_size, 3)

            for batch in dis_batches:
                try:
                    x_batch, y_batch = zip(*batch)
                    feed = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, step = sess.run([dis_train_op, dis_global_step], feed)
                except ValueError:
                    pass

    log.close()
Ejemplo n.º 28
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0
    start_candidates = []
    p_start_candidates = []
    with open(START_TOKEN_CANDIDATES_PATH) as fin:
        for l in fin:
            token = l.strip().split(",")
            start_candidates.append(token[0])
            p_start_candidates.append(float(token[1]))
    start_candidates = np.array(start_candidates, dtype=np.int32)
    p_start_candidates = np.array(p_start_candidates, dtype=np.float32)
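    # start_candidates / p_start_candidates define a categorical distribution
    # over possible first tokens; get_start_token presumably draws from it,
    # e.g. np.random.choice(start_candidates, size=n, p=p_start_candidates)
    # (an assumption -- the helper itself is defined elsewhere).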

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
#    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
    vocab_size = VOCAB_SIZE
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    #target_params = pickle.load(open('save/target_params_py3.pkl', "rb"))
    #target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim, 
                                filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    #generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        #if epoch % 5 == 0:
        #    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
        #    likelihood_data_loader.create_batches(eval_file)
        #    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
        #    print('pre-train epoch ', epoch, 'test_loss ', test_loss)
        #    buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
        #    log.write(buffer)
        buffer = 'epoch:%i\tnll:%f'%(epoch, loss)
        print(buffer)
        log.write(buffer+"\n")

    print('Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for idx in range(50):
        print(idx)
        start_tokens = get_start_token(start_candidates, generated_num, p=p_start_candidates)
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file+".pre%i"%idx, start_tokens)
        dis_data_loader.load_train_data(positive_file, negative_file+".pre%i"%idx)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    saver = tf.train.Saver()
    for total_batch in range(TOTAL_BATCH):
        print(total_batch)
        # Train the generator for one step
        for it in range(1):
            start_token = get_start_token(start_candidates, BATCH_SIZE, p=p_start_candidates)
            samples = generator.generate(sess, start_token)
            rewards = rollout.get_reward(sess, samples, 16, discriminator, start_token)
            feed = {generator.x: samples, generator.rewards: rewards, generator.start_token: start_token}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        #if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
        #    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
        #    likelihood_data_loader.create_batches(eval_file)
        #    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
        #    buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
        #    print('total_batch: ', total_batch, 'test_loss: ', test_loss)
        #    log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            start_tokens = get_start_token(start_candidates, generated_num, p=p_start_candidates)
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file+".gan%i"%total_batch, start_tokens)
            dis_data_loader.load_train_data(positive_file, negative_file+".gan%i"%total_batch)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
        if total_batch % 20 == 0 or total_batch == TOTAL_BATCH - 1:
            saver.save(sess, "save/*model_%i.ckpt"%total_batch, global_step=total_batch)

    log.close()
Ejemplo n.º 29
0
def main(FLAGS):
    #########################################################################################
    #  Generator  Hyper-parameters
    ######################################################################################
    EMB_DIM = FLAGS.gen_emb_dim  # 32  # embedding dimension
    HIDDEN_DIM = FLAGS.gen_hidden_dim  # 32  # hidden state dimension of lstm cell
    SEQ_LENGTH = FLAGS.seq_len  # 20  # sequence length
    START_TOKEN = 0
    PRE_EPOCH_NUM = FLAGS.gen_pretrain_epoch_num  # 120 # supervise (maximum likelihood estimation) epochs for generator
    DISC_PRE_EPOCH_NUM = FLAGS.dis_pretrain_epoch_num  # 50 # supervise (maximum likelihood estimation) epochs for discriminator
    SEED = 88
    BATCH_SIZE = FLAGS.batch_size  #64
    gen_dropout_keep_prob = FLAGS.gen_dropout_keep_prob  # 0.75
    gen_num_recurrent_layers = FLAGS.gen_num_recurrent_layers  # 1
    gen_learning_rate = FLAGS.gen_learning_rate

    #########################################################################################
    #  Discriminator  Hyper-parameters
    #########################################################################################
    dis_embedding_dim = FLAGS.dis_emb_dim  # 64
    dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
    dis_num_filters = [
        100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160
    ]
    dis_dropout_keep_prob = 0.75
    dis_l2_reg_lambda = 0.2
    dis_batch_size = FLAGS.batch_size  #64

    #########################################################################################
    #  Basic Training Parameters
    #########################################################################################
    EXPERIMENT_NAME = FLAGS.experiment_name
    TOTAL_BATCH = FLAGS.num_epochs  # 200 #num of adversarial epochs
    positive_file = 'save/real_data_%0s.txt' % EXPERIMENT_NAME
    negative_file = 'save/generator_sample_%0s.txt' % EXPERIMENT_NAME
    eval_file = "save/eval_file_%0s" % EXPERIMENT_NAME
    generated_num = 10000  # 10000

    #########################################################################################
    #  Data configurations
    #########################################################################################
    use_real_world_data = True
    real_data_file_path = FLAGS.dataset_path  # './data/text8/text8'
    dataset_name = os.path.basename(real_data_file_path)
    base_token = FLAGS.base_token  # 'char'

    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    if use_real_world_data:

        real_data_train_file = real_data_file_path + '-train'
        real_data_valid_file = real_data_file_path + '-valid'
        real_data_test_file = real_data_file_path + '-test'
        real_data_dict_file = real_data_file_path + '-{}-dict.json'.format(
            base_token)

        if not os.path.exists(real_data_train_file):
            split_text8(real_data_file_path)

        map, inv_map = create_real_data_dict(real_data_train_file,
                                             real_data_dict_file, base_token)
        vocab_size = len(map)

        if dataset_name == 'text8' and base_token == 'char':
            assert vocab_size == 27  # SORRY FOR THE HARD CODING
        elif dataset_name == 'ptb' and base_token == 'word':
            assert vocab_size == 10001  # SORRY FOR THE HARD CODING
        elif dataset_name == 'toy' and base_token == 'word':
            assert vocab_size == 8  # SORRY FOR THE HARD CODING
        elif dataset_name == 'wt2' and base_token == 'word':
            assert vocab_size == 33279  # SORRY FOR THE HARD CODING
        else:
            raise TypeError

        gen_data_loader = Gen_Data_loader_text(BATCH_SIZE,
                                               map,
                                               inv_map,
                                               seq_len=SEQ_LENGTH,
                                               token_type=base_token)
        dis_data_loader = Dis_dataloader_text(BATCH_SIZE,
                                              map,
                                              inv_map,
                                              seq_len=SEQ_LENGTH,
                                              token_type=base_token)

    else:
        gen_data_loader = Gen_Data_loader(BATCH_SIZE)
        likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
        vocab_size = 5000
        dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size,
                          BATCH_SIZE,
                          EMB_DIM,
                          HIDDEN_DIM,
                          SEQ_LENGTH,
                          START_TOKEN,
                          dropout_keep_prob=gen_dropout_keep_prob,
                          num_recurrent_layers=gen_num_recurrent_layers)

    if not use_real_world_data:
        target_params = pickle.load(open('save/target_params.pkl', 'rb'))
        target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                                  SEQ_LENGTH, START_TOKEN,
                                  target_params)  # The oracle model

    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    sess = tf.Session(config=config)
    saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=999999)
    sess.run(tf.global_variables_initializer())

    if use_real_world_data:
        # gen_data_loader.create_batches(real_data_train_file)
        pass
    else:
        # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
        generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                         positive_file)
        gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        print("start epoch %0d" % epoch)

        # update learning rate
        if epoch > 5:
            gen_learning_rate /= FLAGS.gen_learning_decay * 1.
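            # i.e. exponential decay: from epoch 6 onward the pre-training
            # learning rate is divided by FLAGS.gen_learning_decay once per epoch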

        if epoch % FLAGS.save_each_epochs == 0:
            print(
                '#########################################################################'
            )
            print('saving model...')
            save_file = os.path.join(
                '.', 'ckp', EXPERIMENT_NAME + '_pretrain_epoch_%0d' % epoch,
                EXPERIMENT_NAME + '_pretrain_epoch_%0d' % epoch)
            saver.save(sess, save_file)

        if use_real_world_data:
            gen_data_loader.create_batches(real_data_train_file,
                                           limit_num_samples=generated_num)

        loss = pre_train_epoch(sess, generator, gen_data_loader,
                               gen_learning_rate)
        if epoch % 1 == 0:
            if use_real_world_data:
                generate_real_data_samples(
                    sess, generator, BATCH_SIZE, generated_num,
                    eval_file + "_epoch_%0d.txt" % epoch, inv_map, base_token)
                test_loss = 0  # FIXME - TEMP
            else:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 eval_file)
                likelihood_data_loader.create_batches(eval_file)
                test_loss = target_loss(sess, target_lstm,
                                        likelihood_data_loader)

            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for epoch in range(DISC_PRE_EPOCH_NUM):
        print("start epoch %0d" % epoch)
        if use_real_world_data:
            generate_real_data_samples(sess, generator, BATCH_SIZE,
                                       generated_num, negative_file, inv_map,
                                       base_token)
            dis_data_loader.load_train_data(real_data_train_file,
                                            negative_file)
        else:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        print("start epoch %0d" % total_batch)

        if total_batch % FLAGS.save_each_epochs == 0:
            print(
                '#########################################################################'
            )
            print('saving model...')
            save_file = os.path.join(
                '.', 'ckp', EXPERIMENT_NAME + '_epoch_%0d' % total_batch,
                EXPERIMENT_NAME + '_epoch_%0d' % total_batch)
            saver.save(sess, save_file)

        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {
                generator.x: samples,
                generator.rewards: rewards,
                generator.learning_rate: 0.01
            }
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            if not use_real_world_data:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 eval_file)
                likelihood_data_loader.create_batches(eval_file)
                test_loss = target_loss(sess, target_lstm,
                                        likelihood_data_loader)
                buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                    test_loss) + '\n'
                print('total_batch: ', total_batch, 'test_loss: ', test_loss)
                log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):

            if use_real_world_data:
                generate_real_data_samples(sess, generator, BATCH_SIZE,
                                           generated_num, negative_file,
                                           inv_map, base_token)
                dis_data_loader.load_train_data(real_data_train_file,
                                                negative_file)
            else:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    print(
        '#########################################################################'
    )
    print('saving model...')
    save_file = os.path.join('.', 'ckp', EXPERIMENT_NAME, EXPERIMENT_NAME)
    saver.save(sess, save_file)

    #
    # print '#########################################################################'
    # print 'Start Language Model Evaluation...'
    # test_data_loader = Gen_Data_loader_text(BATCH_SIZE,map,inv_map)
    # test_data_loader.create_batches(real_data_test_file)
    # language_model_evaluation(sess,generator, test_data_loader)

    log.close()
Ejemplo n.º 30
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    # assert START_TOKEN == 0

    vocab_size = NUM_EMB
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          MAX_LENGTH, START_TOKEN)
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              MAX_LENGTH, 0)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(sequence_length=MAX_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      embedding_size=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [
        param for param in tf.trainable_variables()
        if 'discriminator' in param.name
    ]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss,
                                                         cnn_params,
                                                         aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars,
                                                 global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    def train_discriminator():
        if D_WEIGHT == 0:
            return 0, 0

        negative_samples = generate_samples(sess, generator, BATCH_SIZE,
                                            POSITIVE_NUM)

        #        global positive_samples
        #        pos_new=positive_samples
        # a random fraction of positive samples (5% in the commented-out code below) is labeled negatively to weaken the generator and avoid collapsing training
        #        random.shuffle(pos_new)
        #        length=len(pos_new)
        #        fake_neg_number= int(0.05*length)
        #        fake_neg= pos_new[:fake_neg_number]
        #        pos_new=pos_new[fake_neg_number:]

        #       negative_samples+=fake_neg
        #      random.shuffle(negative_samples)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_samples, negative_samples)
        dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train),
                                                 dis_batch_size,
                                                 dis_num_epochs)

        for batch in dis_batches:
            x_batch, y_batch = zip(*batch)
            feed = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: dis_dropout_keep_prob
            }
            _, step, loss, accuracy = sess.run(
                [dis_train_op, dis_global_step, cnn.loss, cnn.accuracy], feed)
        print('\tD loss  :   {}'.format(loss))
        print('\tAccuracy: {}'.format(accuracy))
        return loss, accuracy

    # Check for a previous session; pre-training is checkpointed and only executes if we don't find a checkpoint
    saver = tf.train.Saver()

    #check previous session
    prev_sess = False
    ckpt_dir = 'checkpoints/{}'.format(PREFIX)

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
#    ckpt_file = os.path.join(ckpt_dir, ckpt_dir + '_model')   #old checkpoint
    ckpt_file = os.path.join(
        ckpt_dir, PREFIX + '_model_'
    )  # new checkpoint naming; iterate over the checkpoints below to find the largest saved batch index

    nbatches_max = 0
    for i in range(500):  # the maximal number of batch iterations is 500
        if os.path.isfile(ckpt_file + str(i) +
                          '.meta'):  #and params["LOAD_PREV_SESS"]
            nbatches_max = i

    # end: find the latest saved checkpoint
    ckpt_file = ckpt_file + str(nbatches_max) + '.meta'
    if params["LOAD_PREV_SESS"]:  # and os.path.isfile(ckpt_file):
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))

        print('Previous session loaded from previous checkpoint {}'.format(
            ckpt_file))
        prev_sess = True
    else:
        if params["LOAD_PREV_SESS"]:
            print('\t* No previous session data found as {:s}.'.format(
                ckpt_file))
        else:
            print('\t* LOAD_PREV_SESS was set to false.')

    if prev_sess == False:
        #check pretraining
        ckpt_dir = 'checkpoints/{}_pretrain'.format(PREFIX)
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        ckpt_file = os.path.join(ckpt_dir, 'pretrain_ckpt')
        if os.path.isfile(ckpt_file + '.meta') and params["LOAD_PRETRAIN"]:
            saver.restore(sess, ckpt_file)
            print('Pretrain loaded from previous checkpoint {}'.format(
                ckpt_file))
        else:
            if params["LOAD_PRETRAIN"]:
                print('\t* No pre-training data found as {:s}.'.format(
                    ckpt_file))
            else:
                print('\t* LOAD_PRETRAIN was set to false.')

            sess.run(tf.global_variables_initializer())
            pretrain(sess, generator, target_lstm, train_discriminator)
            path = saver.save(sess, ckpt_file)
            print('Pretrain finished and saved at {}'.format(path))


    # end loading previous session or pre-training

    # create the reward function
    batch_reward = make_reward(train_samples)
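    # D_WEIGHT (used in get_reward further down) presumably mixes this
    # objective-based reward with the discriminator probability, in the spirit
    # of objective-reinforced GANs; with D_WEIGHT == 0 the train_discriminator
    # helper above returns immediately and only the hand-crafted reward drives
    # learning.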

    rollout = ROLLOUT(generator, 0.8)

    #    nbatches_max= 30

    print(
        '#########################################################################'
    )
    print('Start Reinforcement Training Generator...')
    results_rows = []

    if nbatches_max + 1 > TOTAL_BATCH:
        print(
            ' We already trained that many batches: Check the Checkpoints folder or take a larger TOTAL_BATCH'
        )
    else:
        for nbatch in tqdm(range(nbatches_max + 1, TOTAL_BATCH)):

            #for nbatch in tqdm(range(TOTAL_BATCH)):
            results = OrderedDict({'exp_name': PREFIX})
            if nbatch % 1 == 0 or nbatch == TOTAL_BATCH - 1:
                print('* Making samples')
                if nbatch % 10 == 0:
                    gen_samples = generate_samples(sess, generator, BATCH_SIZE,
                                                   BIG_SAMPLE_NUM)
                else:
                    gen_samples = generate_samples(sess, generator, BATCH_SIZE,
                                                   SAMPLE_NUM)
                likelihood_data_loader.create_batches(gen_samples)
                test_loss = target_loss(sess, target_lstm,
                                        likelihood_data_loader)
                print('batch_num: {}'.format(nbatch))
                print('test_loss: {}'.format(test_loss))
                results['Batch'] = nbatch
                results['test_loss'] = test_loss

                if test_loss < best_score:
                    best_score = test_loss
                    print('best score: %f' % test_loss)

                # results
                mm.compute_results(gen_samples, train_samples, ord_dict,
                                   results)

            print(
                '#########################################################################'
            )
            print('-> Training generator with RL.')
            print('G Epoch {}'.format(nbatch))

            for it in range(TRAIN_ITER):
                samples = generator.generate(sess)
                rewards = rollout.get_reward(sess, samples, 16, cnn,
                                             batch_reward, D_WEIGHT)
                nll = generator.generator_step(sess, samples, rewards)
                # results
                print_rewards(rewards)
                print('neg-loglike: {}'.format(nll))
                results['neg-loglike'] = nll
            rollout.update_params()

            # generate for discriminator
            print('-> Training Discriminator')
            for i in range(D):
                print('D_Epoch {}'.format(i))
                d_loss, accuracy = train_discriminator()
                results['D_loss_{}'.format(i)] = d_loss
                results['Accuracy_{}'.format(i)] = accuracy
            print('results')
            results_rows.append(results)
            if nbatch % params["EPOCH_SAVES"] == 0:
                save_results(sess, PREFIX, PREFIX + '_model_' + str(nbatch),
                             results_rows)

        # write the final results after the training loop
        save_results(sess, PREFIX, PREFIX + '_model_' + str(nbatch),
                     results_rows)

    print('\n*** FINISHED ***')
    return
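# The checkpoint handling in this example brute-force scans candidate files to
# recover the largest saved batch index before resuming training. A minimal
# standalone sketch of that pattern (the helper name find_latest_batch and the
# 500-file cap are illustrative assumptions, not part of the original code):
import os

def find_latest_batch(ckpt_dir, prefix, max_batches=500):
    """Return the largest batch index i for which <prefix>_model_<i>.meta exists."""
    base = os.path.join(ckpt_dir, prefix + '_model_')
    latest = 0
    for i in range(max_batches):
        if os.path.isfile(base + str(i) + '.meta'):
            latest = i
    return latest

# Usage (hypothetical paths): resume from checkpoints/run1/run1_model_<latest>.meta
# latest = find_latest_batch('checkpoints/run1', 'run1')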
Ejemplo n.º 31
0
def main():
    # random.seed(SEED)
    # np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing

    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator2(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN,learning_rate=0.03)


    discriminator = RNNDiscriminator2(sequence_length=SEQ_LENGTH, nrof_class=2, vocab_size=vocab_size, emb_dim=dis_embedding_dim,
                                     batch_size = dis_batch_size,hidden_dim = 2*HIDDEN_DIM, learning_rate = 0.03)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Create Saver
    saver_pretrain = tf.train.Saver(max_to_keep=10)
    saver = tf.train.Saver(max_to_keep=10)

    model_idx = 1
    fname = 'model' + str(model_idx)
    model_save_path = './Model/' + fname + '/'

    while os.path.exists(model_save_path):
        model_idx += 1
        fname = 'model' + str(model_idx)
        model_save_path = './Model/' + fname + '/'

    pre_model_save_path = './Model/' + fname + '_pre/'

    os.makedirs(model_save_path)
    os.makedirs(pre_model_save_path)
    # os.makedirs(os.path.join('./log', fname))

    pretrain_fname = fname+'_pre'

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution

    gen_data_loader.create_batches(positive_file)

    #  pre-train generator
    print 'Start pre-training...'
    early_stop_buffer = [10.]*5
    for epoch in xrange(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)

        if epoch % 2 == 0:
            # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_real_file)
            test_loss = target_loss(sess, generator, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            early_stop_buffer = early_stop_buffer[1:]
            early_stop_buffer.append(test_loss)
            if all(early_stop_buffer[0] < np.asarray(early_stop_buffer[1:])):
                break

            elif all(early_stop_buffer[-1] < np.asarray(early_stop_buffer[:-1])):   # save on local min
                saver_pretrain.save(sess, os.path.join(pre_model_save_path, pretrain_fname), global_step=epoch, write_meta_graph=False)

                metagraph_filename = os.path.join(pre_model_save_path, pretrain_fname + '.meta')

                if not os.path.exists(metagraph_filename):
                    saver.export_meta_graph(metagraph_filename)

    saver.restore(sess,tf.train.latest_checkpoint(pre_model_save_path))


    print 'Start pre-training discriminator...'
    # Train 3 epochs on the generated data and repeat this 50 times
    for e in range(50):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch
                    }
                _ = sess.run(discriminator.train_op, feed)
            print 'Epoch {}'.format(e)
    rollout = ROLLOUT(generator, 0.7)

    print '#########################################################################'
    print 'Start Adversarial Training...'


    early_stop_buffer = [10.] * 6
    for total_batch in range(TOTAL_BATCH):

        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, SAMP_NUM, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 2 == 0 or total_batch == TOTAL_BATCH - 1:
            # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_real_file)
            test_loss = target_loss(sess, generator, likelihood_data_loader)
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss

            # early_stop_buffer = early_stop_buffer[1:]
            # early_stop_buffer.append(test_loss)
            # if all(early_stop_buffer[0] < np.asarray(early_stop_buffer[1:])):
            #     break

            # elif all(early_stop_buffer[-1] < np.asarray(early_stop_buffer[:-1])):   # save on local min
            saver.save(sess, os.path.join(model_save_path, fname), global_step=total_batch, write_meta_graph=False)

            metagraph_filename = os.path.join(model_save_path, fname + '.meta')

            if not os.path.exists(metagraph_filename):
                saver.export_meta_graph(metagraph_filename)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(3):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)


            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {discriminator.input_x: x_batch, discriminator.input_y: y_batch }
                _ = sess.run(discriminator.train_op, feed)
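# Pre-training in this example early-stops on a sliding window of recent test
# losses: it stops when the oldest loss in the window beats every later one (no
# further improvement) and checkpoints when the newest loss beats every earlier
# one (a new local minimum). A minimal sketch of that logic, assuming a plain
# Python list as the buffer (function names are illustrative):
import numpy as np

def update_buffer(buffer, new_loss):
    """Drop the oldest loss and append the newest one."""
    return buffer[1:] + [new_loss]

def should_stop(buffer):
    """True if the oldest loss is lower than every more recent loss."""
    return all(buffer[0] < np.asarray(buffer[1:]))

def is_local_min(buffer):
    """True if the newest loss is lower than every previous loss."""
    return all(buffer[-1] < np.asarray(buffer[:-1]))

# Example: buf = [10.] * 5; buf = update_buffer(buf, 8.7)
# is_local_min(buf) -> True, so this is where the pre-training saver would save.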
Ejemplo n.º 32
0
def main():
    
    #  Create a parser to parse user input
    def parse_arguments():
        parser = argparse.ArgumentParser(description='Program for running several SeqGan applications.')
        parser.add_argument('app', metavar='application', type=str, choices=['obama', 'haiku', 'synth'],
                        help='Enter \'obama\', \'haiku\', or \'synth\'')
        parser.add_argument('gen_n', type = int,
                        help='Number of generator pre-training steps')
        parser.add_argument('disc_n', type = int,
                        help='Number of discriminator pre-training steps')
        parser.add_argument('adv_n', type = int,
                        help='Number of adversarial pre-training steps')
        parser.add_argument('-mn', metavar="model_name", type = str, default = "",
                        help = "Name for the checkpoint files. Will be stored at ./<app>/models/<model_name>")
        parser.add_argument('-numeat', metavar="num_eat", type = int, default = 500,
                        help = "For synthetic data generation. Determines number of eaters in vocab.")
        parser.add_argument('-numfeed', metavar="num_feed", type = int, default = 500,
                        help = "For synthetic data generation. Determines number of feeders in vocab.")
        parser.add_argument('-numsent', metavar="num_sent", type = int, default = 10000,
                        help = "For synthetic data generation. Determines number of sentences generated.")
        args = parser.parse_args()

        synth_gen_params = ("NA", "NA", "NA")
        if args.app == "synth":
            synth_gen_params = (args.numsent, args.numfeed, args.numeat)
            generate_random_sents("../data/synth/input.txt", args.numsent, args.numfeed, args.numeat)

        task = load_task(args.app)

        #Make the /models directory if it's not there.
        model_string = task.path +"/models/"
        if not os.path.exists("./"+model_string):
            os.mkdir("./"+model_string)
    
        #Make the checkpoint directory if it's not there.
        if args.mn == "":
            model_string += str(args.gen_n)+ "_" + str(args.disc_n) + "_" + str(args.adv_n)
            model_string += time.strftime("_on_%m_%d_%y", time.gmtime())
        else:
            model_string += args.mn
        if not os.path.exists("./"+model_string):
            os.mkdir("./"+model_string)
    
        return args.gen_n, args.disc_n, args.adv_n, model_string, task, synth_gen_params
    
    gen_n, disc_n, adv_n, MODEL_STRING, task, SYNTH_GEN_PARAMS = parse_arguments()


    assert START_TOKEN == 0

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

    # Initialize the data loaders
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, task.max_seq_length)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, task.max_seq_length) # For validation
    dis_data_loader = Dis_dataloader(BATCH_SIZE, task.max_seq_length)

    # Initialize the Generator
    generator = Generator(len(task.vocab), BATCH_SIZE, EMB_DIM, HIDDEN_DIM, 
                          task.max_seq_length, START_TOKEN)

    # Initialize the Discriminator
    discriminator = Discriminator(sequence_length=task.max_seq_length, 
                                  num_classes=2, 
                                  vocab_size=len(task.vocab), 
                                  embedding_size=dis_embedding_dim, 
                                  filter_sizes=dis_filter_sizes, 
                                  num_filters=dis_num_filters, 
                                  l2_reg_lambda=dis_l2_reg_lambda)

    # Set session configurations. 
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    saver = tf.train.Saver()
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # If restoring from a previous run ....
    if len(os.listdir("./"+MODEL_STRING)) > 0:
        saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))


    # Create batches from the positive file.
    gen_data_loader.create_batches(task.train_file)

    # Open log file for writing
    log = open(task.log_file, 'w')

    # Pre_train the generator with MLE. 
    pre_train_generator(sess, saver, MODEL_STRING, generator, gen_data_loader, 
                        likelihood_data_loader, task, log, gen_n, BATCH_SIZE,
                        task.generated_num)
    print('Start pre-training discriminator...')

    # Do the discriminator pre-training steps
    saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))
    train_discriminator(sess, generator, discriminator, dis_data_loader, 
                        task, log, disc_n, BATCH_SIZE, task.generated_num,
                        dis_dropout_keep_prob)
    print("Saving checkpoint ...")
    saver.save(sess, MODEL_STRING+ "/model")
    
    # Do the adversarial training steps
    rollout = ROLLOUT(generator, 0.8)
    train_adversarial(sess, saver, MODEL_STRING, generator, discriminator, 
                      rollout, dis_data_loader, likelihood_data_loader, 
                      task, log, adv_n)

    #Use the best model to generate final sample
    saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))
    generate_samples(sess, generator, BATCH_SIZE, task.generated_num, 
                     task.eval_file)


    #Writing results to CSV
    with open(task.eval_file) as f:
        generated = []
        for line in f:
            generated.append(line.strip().split())
        generated = task.vocab.decode(generated)

    with open(task.test_file) as f:
        references = []
        for line in f:
            references.append(line.strip().split())
        references = task.vocab.decode(references)

    bleu = corpus_bleu([references] * len(generated), generated)
    print("Run with args {} {} {}: BLEU score = {}\n".format(gen_n, disc_n, adv_n, bleu))
    
    prop = "NA"

    if task.name == "synth":
        total_correct = 0
        for sentence in generated:
            if is_valid_phrase(sentence):
                total_correct +=1
        prop = total_correct/len(generated)
        
    if not os.path.exists("./results.csv"):
        os.mknod("./results.csv")

    with open("./results.csv", 'a') as csvfile:
        fieldnames = ["name", "task_name", "num_gen", "num_disc", "num_adv",
                    "num_sents", "num_feeders", "num_eaters", "BLEU", "prop_valid"]
        writer = csv.DictWriter(csvfile, fieldnames = fieldnames)
        writer.writeheader()
        writer.writerow({"name": MODEL_STRING, "task_name": task.name,  "num_gen": gen_n, 
                        "num_disc":disc_n, "num_adv": adv_n, "num_sents":SYNTH_GEN_PARAMS[0],
                        "num_feeders":SYNTH_GEN_PARAMS[1], "num_eaters":SYNTH_GEN_PARAMS[2],
                        "BLEU": blue, "prop_valid": prop})
        f.close()


    log.close()
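# The evaluation step in this example scores every generated sentence against
# the full pool of reference sentences with corpus-level BLEU from NLTK. A
# minimal sketch of that scoring pattern on toy token lists (the smoothing
# function is an added assumption to avoid zero scores on tiny corpora; it is
# not used in the original code):
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

references = [['the', 'cat', 'sat', 'on', 'the', 'mat'],
              ['a', 'dog', 'ate', 'the', 'food']]
generated = [['the', 'cat', 'sat', 'on', 'a', 'mat'],
             ['the', 'dog', 'ate', 'food']]

# Every hypothesis is paired with the same reference pool, mirroring
# corpus_bleu([references] * len(generated), generated) above.
score = corpus_bleu([references] * len(generated), generated,
                    smoothing_function=SmoothingFunction().method1)
print("BLEU = {:.3f}".format(score))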
Ejemplo n.º 33
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = 5000
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = get_trainable_model(vocab_size)
    target_params = cPickle.load(open('save/target_params.pkl', 'rb'))
    target_lstm = TARGET_LSTM(vocab_size, 64, 32, 32, 20, 0, target_params)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(
            sequence_length=20,
            num_classes=2,
            vocab_size=vocab_size,
            embedding_size=dis_embedding_dim,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [param for param in tf.trainable_variables() if 'discriminator' in param.name]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss, cnn_params, aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars, global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    generate_samples(sess, target_lstm, 64, 10000, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('log/experiment-log.txt', 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    significance_test(sess, target_lstm, likelihood_data_loader, 'significance/supervise.txt')

    print 'Start training discriminator...'
    for _ in range(dis_alter_epoch):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(positive_file, negative_file)
        dis_batches = dis_data_loader.batch_iter(
            zip(dis_x_train, dis_y_train), dis_batch_size, dis_num_epochs
        )

        for batch in dis_batches:
            try:
                x_batch, y_batch = zip(*batch)
                feed = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, step = sess.run([dis_train_op, dis_global_step], feed)
            except ValueError:
                pass

    rollout = ROLLOUT(generator, 0.8)

    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    for total_batch in range(TOTAL_BATCH):
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, cnn)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss], feed_dict=feed)

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

            if test_loss < best_score:
                best_score = test_loss
                print 'best score: ', test_loss
                significance_test(sess, target_lstm, likelihood_data_loader, 'significance/seqgan.txt')

        rollout.update_params()

        # generate for discriminator
        print 'Start training discriminator'
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

            dis_x_train, dis_y_train = dis_data_loader.load_train_data(positive_file, negative_file)
            dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train), dis_batch_size, 3)

            for batch in dis_batches:
                try:
                    x_batch, y_batch = zip(*batch)
                    feed = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, step = sess.run([dis_train_op, dis_global_step], feed)
                except ValueError:
                    pass

    log.close()
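# In the reference SeqGAN implementation, the ROLLOUT object used for Monte
# Carlo reward estimation is a lagged copy of the generator: update_params()
# moves each rollout weight toward the corresponding generator weight with an
# exponential moving average, and the constructor argument (0.8 here) is that
# update rate. A minimal numpy sketch of the assumed update rule (names are
# illustrative, not the actual ROLLOUT API):
import numpy as np

def ema_update(rollout_weights, generator_weights, update_rate=0.8):
    """Pull each rollout weight toward its generator counterpart."""
    return [update_rate * w_r + (1.0 - update_rate) * w_g
            for w_r, w_g in zip(rollout_weights, generator_weights)]

# Example: after the generator changes, the rollout policy lags behind it.
w_rollout = [np.zeros((2, 2))]
w_generator = [np.ones((2, 2))]
w_rollout = ema_update(w_rollout, w_generator)  # every entry is now 0.2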