Example #1
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    vocab_size = 5000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              SEQ_LENGTH, START_TOKEN,
                              target_params)  # The oracle model

    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                     positive_file)
    gen_data_loader.create_batches(positive_file)
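    # Editor's note: generate_samples is assumed to be the usual SeqGAN helper (it is not
    # shown in any of these examples). A minimal sketch of what it is expected to do:
    #
    #     def generate_samples(sess, trainable_model, batch_size, generated_num, output_file):
    #         # sample generated_num sequences in batches and write one sequence per line
    #         generated_samples = []
    #         for _ in range(int(generated_num / batch_size)):
    #             generated_samples.extend(trainable_model.generate(sess))
    #         with open(output_file, 'w') as fout:
    #             for sample in generated_samples:
    #                 fout.write(' '.join(str(token) for token in sample) + '\n')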

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    # Train 3 epochs on the generated data, and repeat this 50 times
    for _ in range(50):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
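            # Editor's note: in the reference SeqGAN implementation, `rewards` has shape
            # [BATCH_SIZE, SEQ_LENGTH]; entry (i, t) is the discriminator's average "real"
            # probability over 16 Monte Carlo rollouts that keep the first t+1 tokens of
            # sample i fixed, and g_updates applies a REINFORCE-style policy-gradient step
            # weighted by these per-token rewards.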

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()
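        # Editor's note: ROLLOUT(generator, 0.8) keeps a delayed copy of the generator for
        # Monte Carlo rollouts; in the reference implementation update_params() blends it
        # toward the current generator roughly as
        #     rollout_param = 0.8 * rollout_param + (1 - 0.8) * generator_param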

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()
Example #2
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    # TODO: I changed this.  Why was this asserted?  Was it just to ensure the replication
    # of results?  Or is zero important otherwise?
    # Changed because 0 is a bad start token for our data.  (cannot have home label=0)
    # assert START_TOKEN == 0

    # set up logging
    log_fpath = logger.get_experiment_log_filepath()

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH) # For testing
    vocab_size = VOCAB_SIZE
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size,
                                    embedding_size=dis_embedding_dim, filter_sizes=dis_filter_sizes,
                                    num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    if not USE_GPU:
        # Prevent the environment from seeing the available GPUs (to avoid error on matlaber cluster)
        import os
        os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    gen_data_loader.create_batches(real_file)

    #  pre-train generator
    logger.write_log(log_fpath, 'pre-training generator...')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            logger.write_log(log_fpath, 'generator loss:')
            logger.log_progress(log_fpath, epoch, loss)
            generate_samples(sess, generator, BATCH_SIZE, eval_generated_num, eval_file.format('pretrain'))

    logger.write_log(log_fpath, 'Start pre-training discriminator...')
    # Train 3 epochs on the generated data, and repeat this 50 times
    for i in range(50):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, fake_file)
        dis_data_loader.load_train_data(real_file, fake_file)
        # dis_data_loader.load_train_data(positive_file, negative_file)
        logger.write_log(log_fpath, 'epoch iterator:  %s / 50' % i)
        for j in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _d_train_output = sess.run(discriminator.train_op, feed)

    logger.write_log(log_fpath, 'finished pre-training discriminator')
    rollout = ROLLOUT(generator, 0.8)

    logger.write_log(log_fpath, 'Start Adversarial Training...')
    g_steps = 1
    d_steps = 1
    k = 10
    for batch in range(TOTAL_BATCH):
        buff = 'batch %s/%s' % (batch, TOTAL_BATCH)
        logger.write_log(log_fpath, buff)
        # Train the generator for one step
        for it in range(g_steps):
            samples = generator.generate(sess)
            rollout_num = 16  # TODO: experiment with this value
            rewards = rollout.get_reward(sess, samples, rollout_num, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if batch % 5 == 0 or batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, eval_generated_num, eval_file.format(batch))
            logger.write_log(log_fpath, 'generated some more eval samples...')

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(d_steps):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, fake_file)
            dis_data_loader.load_train_data(real_file, fake_file)

            for _ in range(k):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    logger.write_log(log_fpath, 'I\'M DONE')
Example #3
def main(FLAGS):
    #########################################################################################
    #  Generator  Hyper-parameters
    #########################################################################################
    EMB_DIM = FLAGS.gen_emb_dim  # 32  # embedding dimension
    HIDDEN_DIM = FLAGS.gen_hidden_dim  # 32  # hidden state dimension of lstm cell
    SEQ_LENGTH = FLAGS.seq_len  # 20  # sequence length
    START_TOKEN = 0
    PRE_EPOCH_NUM = FLAGS.gen_pretrain_epoch_num  # 120 # supervised (maximum likelihood estimation) epochs for the generator
    DISC_PRE_EPOCH_NUM = FLAGS.dis_pretrain_epoch_num  # 50 # supervised pre-training epochs for the discriminator
    SEED = 88
    BATCH_SIZE = FLAGS.batch_size  #64
    gen_dropout_keep_prob = FLAGS.gen_dropout_keep_prob  # 0.75
    gen_num_recurrent_layers = FLAGS.gen_num_recurrent_layers  # 1
    gen_learning_rate = FLAGS.gen_learning_rate

    #########################################################################################
    #  Discriminator  Hyper-parameters
    #########################################################################################
    dis_embedding_dim = FLAGS.dis_emb_dim  # 64
    dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
    dis_num_filters = [
        100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160
    ]
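    # 12 parallel convolution widths (1-10, 15, 20) with 1,720 feature maps in total,
    # following the CNN text-classifier discriminator used in the reference SeqGAN code.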
    dis_dropout_keep_prob = 0.75
    dis_l2_reg_lambda = 0.2
    dis_batch_size = FLAGS.batch_size  #64

    #########################################################################################
    #  Basic Training Parameters
    #########################################################################################
    EXPERIMENT_NAME = FLAGS.experiment_name
    TOTAL_BATCH = FLAGS.num_epochs  # 200 #num of adversarial epochs
    positive_file = 'save/real_data_%0s.txt' % EXPERIMENT_NAME
    negative_file = 'save/generator_sample_%0s.txt' % EXPERIMENT_NAME
    eval_file = "save/eval_file_%0s" % EXPERIMENT_NAME
    generated_num = 10000  # 10000

    #########################################################################################
    #  Data configurations
    #########################################################################################
    use_real_world_data = True
    real_data_file_path = FLAGS.dataset_path  # './data/text8/text8'
    dataset_name = os.path.basename(real_data_file_path)
    base_token = FLAGS.base_token  # 'char'

    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    if use_real_world_data:

        real_data_train_file = real_data_file_path + '-train'
        real_data_valid_file = real_data_file_path + '-valid'
        real_data_test_file = real_data_file_path + '-test'
        real_data_dict_file = real_data_file_path + '-{}-dict.json'.format(
            base_token)

        if not os.path.exists(real_data_train_file):
            split_text8(real_data_file_path)

        map, inv_map = create_real_data_dict(real_data_train_file,
                                             real_data_dict_file, base_token)
        vocab_size = len(map)

        if dataset_name == 'text8' and base_token == 'char':
            assert vocab_size == 27  # SORRY FOR THE HARD CODING
        elif dataset_name == 'ptb' and base_token == 'word':
            assert vocab_size == 10001  # SORRY FOR THE HARD CODING
        elif dataset_name == 'toy' and base_token == 'word':
            assert vocab_size == 8  # SORRY FOR THE HARD CODING
        elif dataset_name == 'wt2' and base_token == 'word':
            assert vocab_size == 33279  # SORRY FOR THE HARD CODING
        else:
            raise ValueError("unsupported dataset/token combination: %s / %s" %
                             (dataset_name, base_token))

        gen_data_loader = Gen_Data_loader_text(BATCH_SIZE,
                                               map,
                                               inv_map,
                                               seq_len=SEQ_LENGTH,
                                               token_type=base_token)
        dis_data_loader = Dis_dataloader_text(BATCH_SIZE,
                                              map,
                                              inv_map,
                                              seq_len=SEQ_LENGTH,
                                              token_type=base_token)

    else:
        gen_data_loader = Gen_Data_loader(BATCH_SIZE)
        likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
        vocab_size = 5000
        dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size,
                          BATCH_SIZE,
                          EMB_DIM,
                          HIDDEN_DIM,
                          SEQ_LENGTH,
                          START_TOKEN,
                          dropout_keep_prob=gen_dropout_keep_prob,
                          num_recurrent_layers=gen_num_recurrent_layers)

    if not use_real_world_data:
        target_params = pickle.load(open('save/target_params.pkl', 'rb'))
        target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                                  SEQ_LENGTH, START_TOKEN,
                                  target_params)  # The oracle model

    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.3
    sess = tf.Session(config=config)
    saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=999999)
    sess.run(tf.global_variables_initializer())

    if use_real_world_data:
        # gen_data_loader.create_batches(real_data_train_file)
        pass
    else:
        # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
        generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                         positive_file)
        gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        print("start epoch %0d" % epoch)

        # update learning rate
        if epoch > 5:
            gen_learning_rate /= FLAGS.gen_learning_decay * 1.

        if epoch % FLAGS.save_each_epochs == 0:
            print(
                '#########################################################################'
            )
            print('saving model...')
            save_file = os.path.join(
                '.', 'ckp', EXPERIMENT_NAME + '_pretrain_epoch_%0d' % epoch,
                EXPERIMENT_NAME + '_pretrain_epoch_%0d' % epoch)
            saver.save(sess, save_file)

        if use_real_world_data:
            gen_data_loader.create_batches(real_data_train_file,
                                           limit_num_samples=generated_num)

        loss = pre_train_epoch(sess, generator, gen_data_loader,
                               gen_learning_rate)
        if epoch % 1 == 0:
            if use_real_world_data:
                generate_real_data_samples(
                    sess, generator, BATCH_SIZE, generated_num,
                    eval_file + "_epoch_%0d.txt" % epoch, inv_map, base_token)
                test_loss = 0  # FIXME - TEMP
            else:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 eval_file)
                likelihood_data_loader.create_batches(eval_file)
                test_loss = target_loss(sess, target_lstm,
                                        likelihood_data_loader)

            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    # Train 3 epochs on the generated data, and repeat this DISC_PRE_EPOCH_NUM times
    for epoch in range(DISC_PRE_EPOCH_NUM):
        print("start epoch %0d" % epoch)
        if use_real_world_data:
            generate_real_data_samples(sess, generator, BATCH_SIZE,
                                       generated_num, negative_file, inv_map,
                                       base_token)
            dis_data_loader.load_train_data(real_data_train_file,
                                            negative_file)
        else:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        print("start epoch %0d" % total_batch)

        if total_batch % FLAGS.save_each_epochs == 0:
            print(
                '#########################################################################'
            )
            print('saving model...')
            save_file = os.path.join(
                '.', 'ckp', EXPERIMENT_NAME + '_epoch_%0d' % total_batch,
                EXPERIMENT_NAME + '_epoch_%0d' % total_batch)
            saver.save(sess, save_file)

        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {
                generator.x: samples,
                generator.rewards: rewards,
                generator.learning_rate: 0.01
            }
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            if not use_real_world_data:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 eval_file)
                likelihood_data_loader.create_batches(eval_file)
                test_loss = target_loss(sess, target_lstm,
                                        likelihood_data_loader)
                buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                    test_loss) + '\n'
                print('total_batch: ', total_batch, 'test_loss: ', test_loss)
                log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):

            if use_real_world_data:
                generate_real_data_samples(sess, generator, BATCH_SIZE,
                                           generated_num, negative_file,
                                           inv_map, base_token)
                dis_data_loader.load_train_data(real_data_train_file,
                                                negative_file)
            else:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    print(
        '#########################################################################'
    )
    print('saving model...')
    save_file = os.path.join('.', 'ckp', EXPERIMENT_NAME, EXPERIMENT_NAME)
    saver.save(sess, save_file)

    #
    # print '#########################################################################'
    # print 'Start Language Model Evaluation...'
    # test_data_loader = Gen_Data_loader_text(BATCH_SIZE,map,inv_map)
    # test_data_loader.create_batches(real_data_test_file)
    # language_model_evaluation(sess,generator, test_data_loader)

    log.close()
Example #4
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=num_classes, vocab_size=vocab_size, embedding_size=dis_embedding_dim, 
                                filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)
    #target_params = cPickle.load(open('save/target_params.pkl'))
    #target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    # avoid occupying all of the GPU memory
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    #Savers 
    saver_gen = tf.train.Saver()
    saver_dis = tf.train.Saver()
    saver_seqgan = tf.train.Saver()

    # Load the positive examples (real training data)
    gen_data_loader.create_batches(positive_file)  # load the data in

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training Generator...') #MLE
    log.write('pre-training generator...\n') 
    best = float('inf')
    for epoch in range(PRE_GEN_EPOCH_NUM):
        s = time.time()
        loss = pre_train_epoch(sess, generator, gen_data_loader)

        # save the best (lowest-loss) pre-trained generator so far
        if loss < best:
            best = loss
            saver_gen.save(sess, "model/pretrain_gen_best")

        if epoch % 5 == 0:
            print('pre-train epoch: ', epoch, 'loss: ', loss, "time: ", time.time()-s)
            log.write('epoch:\t'+ str(epoch) + '\tloss:\t' + str(loss) + '\n')

    # pre-train discriminator
    print('Start pre-training discriminator...')
    log.write('pre-training discriminator...\n') 
    
    # Train 3 epochs on the generated data, and repeat this PRE_DIS_EPOCH_NUM times
    best_acc = 0.0
    for epoch in range(PRE_DIS_EPOCH_NUM):
        s = time.time()
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _,acc = sess.run([discriminator.train_op,discriminator.accuracy], feed)

        # save the discriminator whenever the (last-batch) accuracy improves
        if acc > best_acc:
            best_acc = acc
            saver_dis.save(sess, "./model/pretrain_dis_best")

        print("pre-train epoch: ", epoch, " acc: ", acc," time: ", time.time()-s)
        log.write("epoch:\t" + str(epoch) + "\tacc:\t" + str(acc) + "\n")

    rollout = ROLLOUT(generator, 0.8)

    print( '#########################################################################')
    print( 'Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        s = time.time()

        for it in range(ADV_GEN_TIME):
            samples = generator.generate(sess)  # sample a batch of sequences
            rewards = rollout.get_reward(sess, samples, 16, discriminator)  # Monte Carlo search rewards
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)  # policy-gradient update

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:  # log average reward
            avg = np.mean(np.sum(rewards, axis=1), axis=0) / SEQ_LENGTH
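            # avg = per-token reward averaged over the batch: sum the rewards over each
            # sequence, average over the batch, then divide by SEQ_LENGTH (uses the rewards
            # from the last generator step above)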
         
            #print('total_batch: ', total_batch, 'average reward: ', avg)
            log.write('epoch:\t' + str(total_batch) + '\treward:\t' + str(avg) + '\n')

            saver_seqgan.save(sess, "./model/seq_gan", global_step=total_batch)

        # Update roll-out parameters
        rollout.update_params()  # sync the rollout policy with the updated generator

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

        print('epoch: ', total_batch, 'average reward: ', avg," time: ",time.time()-s)
    
    log.close() 

    # generate examples
    print("Training finished, generating test samples")
    generate_samples(sess, generator, BATCH_SIZE, test_num, generate_file)
    
    print("Finish")
Example #5
File: train.py  Project: mgetech/SubLoc
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0

    # Build data loaders for the generator, evaluation, and the discriminator
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    #Build generator and its rollout
    generator = Generator(config=config_gen)
    generator.build()
    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    #target_params = cPickle.load(open('save/target_params.pkl'))
    #print(target_params)
    #target_lstm = TARGET_LSTM(config=config_gen, params=target_params) # The oracle model

    #Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    #Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]  #Using name 'teller' here to prevent name collision of target LSTM
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))
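    # Pre-training minimizes generator.pretrained_loss (maximum-likelihood / teacher forcing)
    # over the 'teller' variables only, with global-norm gradient clipping.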

    #Initialize all variables
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())

    # Initialize the generator's data loader
    #generate_samples(sess, target_lstm, config_train.batch_size, config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)
    '''
    #Start pretraining
    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator...')
    log.write('pre-training...\n')
    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _, g_loss = sess.run([gen_pre_update, generator.pretrained_loss], feed_dict={generator.input_seqs_pre:batch,\
                                                                                       generator.input_seqs_mask:np.ones_like(batch)})
        
        if epoch % config_train.test_per_epoch == 0:
            #generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            #test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            #print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            print('pre-train epoch ', epoch)
            #buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' 
            log.write(buffer)
        
    '''
    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        #generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    #Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    #Initialize global variables of optimizer for adversarial training
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    #Start adversarial training
    print('Start adversarial training')
    for total_batch in range(config_train.total_batch):
        print(f"total_batch: {total_batch}")

        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshape)
            #samples = pre(samples,0.8)
            #print(len(samples))
            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            # calculate the reward at each specific step t via rollout
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)
                # stacked rollouts, shape: [batch_size * rollout_steps, sequence_length]
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
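            # Editor's note: `rewards` now has shape [gen_batch_size, sequence_length];
            # entry (r, t) is roughly the discriminator's "real" probability for rollouts
            # that keep the first t+1 tokens of sample r fixed, averaged over rollout_num
            # rollouts (the full sample itself supplies the reward for the final position).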
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv], feed_dict={generator.input_seqs_adv:samples,\
                                                                                        generator.rewards:rewards})

        for _ in range(config_train.dis_update_time_adv):
            my_gen(sess, generator, config_train.batch_size,
                   config_train.generated_num, config_train.negative_feedback)
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.negative_feedback,
                                            config_train.negative_file)
            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
Example #6
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    # data loaders declaration
    # loaders for generator, discriminator, and additional validation data loader
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    eval_data_loader = Gen_Data_loader(BATCH_SIZE)

    # define generator and discriminator
    # general structure is the same as the original model
    # the generator learning rate needs heavy tuning for general use
    # l2 reg for D & G also affects performance
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, GENERATOR_LR, REWARD_GAMMA)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
                                filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    # VRAM limitation for efficient deployment
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)
    sess.run(tf.global_variables_initializer())
    # define saver
    saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=1)
    # generate real data from the true dataset
    gen_data_loader.create_batches(positive_file)
    # generate real validation data from true validation dataset
    eval_data_loader.create_batches(valid_file)

    timestamp = str(datetime.datetime.now())[:-7]  # not in the V1 version: tag the log filename with a timestamp
    log = open('save/experiment-log' + timestamp + '.txt', 'w')
    log.write(str(config) + '\n')  # write the run config at the top of the log
    log.write('D loss: original\n')
    log.flush()

    #summary_writer = tf.summary.FileWriter('save/tensorboard/', graph=tf.get_default_graph())

    if config['pretrain']:
        #  pre-train generator
        print('Start pre-training...')
        log.write('pre-training...\n')
        for epoch in range(PRE_GEN_EPOCH):  # generator pre-training
            # calculate the loss by running an epoch
            loss = pre_train_epoch(sess, generator, gen_data_loader)  # where is pre_train_epoch defined?

            # measure bleu score with the validation set
            bleu_score = calculate_bleu(sess, generator, eval_data_loader)

            # since the real data is the true data distribution, only evaluate the pretraining loss
            # note the absence of the oracle model which is meaningless for general use
            buffer = 'pre-train epoch: ' + str(epoch) + ' pretrain_loss: ' + str(loss) + ' bleu: ' + str(bleu_score)
            print(buffer)
            log.write(buffer + '\n')
            log.flush()

            # generate 5 test samples per epoch
            # it automatically samples from the generator and postprocess to midi file
            # midi files are saved to the pre-defined folder
            if epoch == 0:  # at the start of pre-training
                generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
                POST.main(negative_file, 5, str(-1)+'_vanilla_', 'midi')
            elif epoch == PRE_GEN_EPOCH - 1:  # at the end of pre-training
                generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
                POST.main(negative_file, 5, str(-PRE_GEN_EPOCH)+'_vanilla_', 'midi')


        print('Start pre-training discriminator...')
        # Train 3 epochs on the generated data, and repeat this PRE_DIS_EPOCH times
        # this trick is also in the spirit of the original work, but the epoch strategy needs tuning
        for epochs in range(PRE_DIS_EPOCH):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            D_loss = 0
            for _ in range(3):
                dis_data_loader.load_train_data(positive_file, negative_file)
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
                    D_loss += discriminator.loss.eval(feed, session=sess)
            buffer = 'epoch: ' + str(epochs+1) + '  D loss: ' + str(D_loss/dis_data_loader.num_batch/3)
            print(buffer)
            log.write(buffer + '\n')
            log.flush()

        # save the pre-trained checkpoint for future use
        # if one wants adv. training only, comment out the pre-training section after the save
        save_checkpoint(sess, saver, PRE_GEN_EPOCH, PRE_DIS_EPOCH)

    # define rollout target object
    # the second parameter specifies target update rate
    # the higher rate makes rollout "conservative", with less update from the learned generator
    # we found that higher update rate stabilized learning, constraining divergence of the generator
    rollout = ROLLOUT(generator, ROLLOUT_UPDATE_RATE)
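    # Editor's note: with update rate u, the reference SeqGAN rollout update is roughly
    #     rollout_param = u * rollout_param + (1 - u) * generator_param
    # so a larger ROLLOUT_UPDATE_RATE means the rollout policy tracks the generator more slowly.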

    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    if not config['pretrain']:
        # load checkpoint of pre-trained model
        load_checkpoint(sess, saver)

    # 0.001 to 0.01 (this x10 learning-rate bump is also unique to this version)
    if config['x10adv_g']:
        generator.learning_rate *= 10

    for total_batch in range(TOTAL_BATCH):
        G_loss = 0
        # Train the generator for one step
        for it in range(epochs_generator):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, config['rollout_num'], discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
            G_loss += generator.g_loss.eval(feed, session=sess)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        D_loss = 0
        for _ in range(epochs_discriminator):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            for _ in range(config['epochs_discriminator_multiplier']):
                dis_data_loader.load_train_data(positive_file, negative_file)
                dis_data_loader.reset_pointer()

                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
                    D_loss += discriminator.loss.eval(feed, session=sess)

        # measure stability and evaluate performance with the BLEU score
        bleu_score = calculate_bleu(sess, generator, eval_data_loader)
        buffer = 'epoch: ' + str(total_batch+1) + \
                 ',  G_adv_loss: %.12f' % (G_loss/epochs_generator) + \
                 ',  D loss: %.12f' % (D_loss/epochs_discriminator/config['epochs_discriminator_multiplier']) + \
                 ',  bleu score: %.12f' % bleu_score
        print(buffer)
        log.write(buffer + '\n')
        log.flush()

        # this restart-on-collapse block is also new in this version
        if config['infinite_loop']:
            if bleu_score < config['loop_threshold']:
                buffer = 'Mode collapse detected, restarting from pretrained model...'
                print(buffer)
                log.write(buffer + '\n')
                log.flush()
                load_checkpoint(sess, saver)

        # generate random test samples and postprocess the sequence to midi file
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
        POST.main(negative_file, 5, str(total_batch)+'_vanilla_', 'midi')
    log.close()
Example #7
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    # data loaders declaration
    # loaders for generator, discriminator, and additional validation data loader
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    eval_data_loader = Gen_Data_loader(BATCH_SIZE)

    # define generator and discriminator
    # general structure is the same as the original model
    # the generator learning rate needs heavy tuning for general use
    # l2 reg for D & G also affects performance
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    # VRAM limitation for efficient deployment
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)
    sess.run(tf.global_variables_initializer())
    # define saver
    saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=1)
    # generate real data from the true dataset
    gen_data_loader.create_batches(positive_file)
    # generate real validation data from true validation dataset
    eval_data_loader.create_batches(valid_file)

    log = open('save/experiment-log.txt', 'w')
    if config['pretrain']:
        #  pre-train generator
        print('Start pre-training...')
        log.write('pre-training...\n')
        for epoch in range(PRE_GEN_EPOCH):
            # calculate the loss by running an epoch
            loss = pre_train_epoch(sess, generator, gen_data_loader)

            # measure bleu score with the validation set
            bleu_score = calculate_bleu(sess, generator, eval_data_loader)
            # since the real data is the true data distribution, only evaluate the pretraining loss
            # note the absence of the oracle model which is meaningless for general use
            buffer = 'pre-train epoch: ' + str(
                epoch) + ' pretrain_loss: ' + str(loss) + ' bleu: ' + str(
                    bleu_score)
            print(buffer)
            log.write(buffer + '\n')

            # generate 5 test samples per epoch
            # it automatically samples from the generator and postprocess to midi file
            # midi files are saved to the pre-defined folder
            if epoch == 0:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                POST.main(negative_file, 5, -1)
            elif epoch == PRE_GEN_EPOCH - 1:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                POST.main(negative_file, 5, -PRE_GEN_EPOCH)

        print('Start pre-training discriminator...')
        # Train 3 epochs on the generated data, and repeat this PRE_DIS_EPOCH times
        # this trick is also in the spirit of the original work, but the epoch strategy needs tuning
        for epochs in range(PRE_DIS_EPOCH):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            D_loss = 0
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
                    D_loss += discriminator.loss.eval(feed, session=sess)
            buffer = 'epoch: ' + str(epochs + 1) + '  D loss: ' + str(
                D_loss / dis_data_loader.num_batch / 3)
            print(buffer)
            log.write(buffer + '\n')

        # save the pre-trained checkpoint for future use
        # if one wants adv. training only, comment out the pre-training section after the save
        save_checkpoint(sess, saver, PRE_GEN_EPOCH, PRE_DIS_EPOCH)

    # define rollout target object
    # the second parameter specifies target update rate
    # the higher rate makes rollout "conservative", with less update from the learned generator
    # we found that higher update rate stabilized learning, constraining divergence of the generator
    rollout = ROLLOUT(generator, 0.9)

    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    if not config['pretrain']:
        # load checkpoint of pre-trained model
        load_checkpoint(sess, saver)
    for total_batch in range(TOTAL_BATCH):
        G_loss = 0
        # Train the generator for one step
        for it in range(epochs_generator):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, config['rollout_num'],
                                         discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
            G_loss += generator.g_loss.eval(feed, session=sess)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        D_loss = 0
        for _ in range(epochs_discriminator):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            for _ in range(3):
                dis_data_loader.reset_pointer()

                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
                    D_loss += discriminator.loss.eval(feed, session=sess)

        # measure stability and evaluate performance with the BLEU score
        buffer = 'epoch: ' + str(total_batch+1) + \
                 ',  G_adv_loss: %.12f' % (G_loss/epochs_generator) + \
                 ',  D loss: %.12f' % (D_loss/epochs_discriminator/3) + \
                 ',  bleu score: %.12f' % calculate_bleu(sess, generator, eval_data_loader)
        print(buffer)
        log.write(buffer + '\n')

        # generate random test samples and postprocess the sequence to midi file
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file + "_EP_" + str(total_batch))
        POST.main(negative_file + "_EP_" + str(total_batch), 5, total_batch)
    log.close()
Example #8
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    # assert START_TOKEN == 0

    # gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    # likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
    vocab_size = 5001
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    # generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    # target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
    # target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    # generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    # gen_data_loader.create_batches(positive_file)

    # log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    # print ('Start pre-training...')
    # log.write('pre-training...\n')
    # for epoch in range(PRE_EPOCH_NUM):
    #     loss = pre_train_epoch(sess, generator, gen_data_loader)
    #     if epoch % 5 == 0:
    #         generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    #         likelihood_data_loader.create_batches(eval_file)
    #         test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    #         print('pre-train epoch ', epoch, 'test_loss ', test_loss)
    #         buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
    #         log.write(buffer)

    # print ('Start training discriminator...')
    # Train 3 epochs on the generated data, and repeat this 50 times
    # for _ in range(50):
    #     generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
    print("load data from file")
    dis_data_loader.load_train_data(positive_file, negative_file)
    print("load success!")
    for ep in range(3):
        print("Epoch %d : @@@@@@@" % ep)
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch):
            x_batch, y_batch = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.input_y: y_batch,
                discriminator.dropout_keep_prob: dis_dropout_keep_prob
            }
            loss, _ = sess.run([discriminator.loss, discriminator.train_op],
                               feed)
            if it % 100 == 0:
                print('steps ', it, 'loss ', loss)
Example #9
def main():
    
    #  Create a parser to parse user input
    def parse_arguments():
        parser = argparse.ArgumentParser(description='Program for running several SeqGan applications.')
        parser.add_argument('app', metavar='application', type=str, choices=['obama', 'haiku', 'synth'],
                        help='Enter \'obama\', \'haiku\', or \'synth\'')
        parser.add_argument('gen_n', type=int,
                        help='Number of generator pre-training steps')
        parser.add_argument('disc_n', type=int,
                        help='Number of discriminator pre-training steps')
        parser.add_argument('adv_n', type=int,
                        help='Number of adversarial training steps')
        parser.add_argument('-mn', metavar="model_name", type=str, default="",
                        help="Name for the checkpoint files. Will be stored at ./<app>/models/<model_name>")
        parser.add_argument('-numeat', metavar="num_eat", type=int, default=500,
                        help="For synthetic data generation. Determines number of eaters in vocab.")
        parser.add_argument('-numfeed', metavar="num_feed", type=int, default=500,
                        help="For synthetic data generation. Determines number of feeders in vocab.")
        parser.add_argument('-numsent', metavar="num_sent", type=int, default=10000,
                        help="For synthetic data generation. Determines number of sentences generated.")
        args = parser.parse_args()

        synth_gen_params = ("NA", "NA", "NA")
        if args.app == "synth":
            synth_gen_params = (args.numsent, args.numfeed, args.numeat)
            generate_random_sents("../data/synth/input.txt", args.numsent, args.numfeed, args.numeat)

        task = load_task(args.app)

        # Make the /models directory if it's not there.
        model_string = task.path + "/models/"
        if not os.path.exists("./" + model_string):
            os.mkdir("./" + model_string)

        # Make the checkpoint directory if it's not there.
        if args.mn == "":
            model_string += str(args.gen_n) + "_" + str(args.disc_n) + "_" + str(args.adv_n)
            model_string += time.strftime("_on_%m_%d_%y", time.gmtime())
            adder, i = "", 0
            while os.path.exists("./" + model_string + adder):
                adder = "_" + str(i)
                i += 1
            model_string += adder
        else:
            model_string += args.mn
        if not os.path.exists("./" + model_string):
            os.mkdir("./" + model_string)
    
        return args.gen_n, args.disc_n, args.adv_n, model_string, task, synth_gen_params
    
    gen_n, disc_n, adv_n, MODEL_STRING, task, SYNTH_GEN_PARAMS = parse_arguments()


    assert START_TOKEN == 0

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

    # Initialize the data loaders
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, task.max_seq_length)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, task.max_seq_length) # For validation
    dis_data_loader = Dis_dataloader(BATCH_SIZE, task.max_seq_length)

    # Initialize the Generator
    generator = Generator(len(task.vocab), BATCH_SIZE, EMB_DIM, HIDDEN_DIM, 
                          task.max_seq_length, START_TOKEN)

    # Initialize the Discriminator
    discriminator = Discriminator(sequence_length=task.max_seq_length, 
                                  num_classes=2, 
                                  vocab_size=len(task.vocab), 
                                  embedding_size=dis_embedding_dim, 
                                  filter_sizes=dis_filter_sizes, 
                                  num_filters=dis_num_filters, 
                                  l2_reg_lambda=dis_l2_reg_lambda)

    # Set session configurations. 
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    saver = tf.train.Saver()
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # If restoring from a previous run ....
    if len(os.listdir("./"+MODEL_STRING)) > 0:
        saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))


    # Create batches from the positive file.
    gen_data_loader.create_batches(task.train_file)

    # Open log file for writing
    log = open(task.log_file, 'w')

    # Pre_train the generator with MLE. 
    pre_train_generator(sess, saver, MODEL_STRING, generator, gen_data_loader, 
                        likelihood_data_loader, task, log, gen_n, BATCH_SIZE,
                        task.generated_num)
    print('Start pre-training discriminator...')

    # Do the discriminator pre-training steps
    saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))
    train_discriminator(sess, generator, discriminator, dis_data_loader, 
                        task, log, disc_n, BATCH_SIZE, task.generated_num,
                        dis_dropout_keep_prob)
    print("Saving checkpoint ...")
    saver.save(sess, MODEL_STRING+ "/model")
    
    # Do the adversarial training steps
    rollout = ROLLOUT(generator, 0.8)
    train_adversarial(sess, saver, MODEL_STRING, generator, discriminator, 
                      rollout, dis_data_loader, likelihood_data_loader, 
                      task, log, adv_n)

    #Use the best model to generate final sample
    saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))
    generate_samples(sess, generator, BATCH_SIZE, task.generated_num, 
                     task.eval_file)


    #Writing results to CSV
    with open(task.eval_file) as f:
        generated = []
        for line in f:
            line = line.strip().split()
            generated.append(line)
        generated = task.vocab.decode(generated)

    with open(task.test_file) as f:
        references = []
        for line in f:
            line = line.strip().split()
            references.append(line)
        references = task.vocab.decode(references)

        
    if not os.path.exists("./results.csv"):
        os.mknod("./results.csv")

    with open("./results.csv", 'a') as csvfile:
        fieldnames = ["name", "task_name", "num_gen", "num_disc", "num_adv",
                    "num_sents", "num_feeders", "num_eaters", "BLEU", "prop_valid"]
        writer = csv.DictWriter(csvfile, fieldnames = fieldnames)
        csvfile.seek(0, os.SEEK_END)  # go to end of file
        if not csvfile.tell():  # if the file is empty, write the header row
            writer.writeheader()

        bleu = corpus_bleu([references] * len(generated), generated)
        print("Run with args {} {} {}: BLEU score = {}\n".format(gen_n, disc_n, adv_n, bleu))
        
        prop = "NA"

        if task.name == "synth":
            total_correct = 0
            for sentence in generated:
                if is_valid_phrase(sentence):
                    total_correct +=1
            prop = total_correct/len(generated)

        writer.writerow({"name": MODEL_STRING, "task_name": task.name, "num_gen": gen_n,
                        "num_disc": disc_n, "num_adv": adv_n, "num_sents": SYNTH_GEN_PARAMS[0],
                        "num_feeders": SYNTH_GEN_PARAMS[1], "num_eaters": SYNTH_GEN_PARAMS[2],
                        "BLEU": bleu, "prop_valid": prop})


    log.close()
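# Given the argparse setup above, a typical invocation would look like the following;
# the script filename main.py and the step counts are only illustrative:
#
#   python main.py obama 120 50 200 -mn obama_run1
#   python main.py synth 120 50 200 -numsent 10000 -numfeed 500 -numeat 500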
Example #10
0
File: Main.py  Project: NemoNone/LeakGAN
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE,FLAGS.length)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE,FLAGS.length) # For testing
    vocab_size = 5000
    file = open('save/target_params.pkl', 'rb')
    target_params = cPickle.load(file)
    
    dis_data_loader = Dis_dataloader(BATCH_SIZE,SEQ_LENGTH)
    discriminator = Discriminator(SEQ_LENGTH,num_classes=2,vocab_size=vocab_size,dis_emb_dim=dis_embedding_dim,filter_sizes=dis_filter_sizes,num_filters=dis_num_filters,
                        batch_size=BATCH_SIZE,hidden_dim=HIDDEN_DIM,start_token=START_TOKEN,goal_out_size=GOAL_OUT_SIZE,step_size=4)
    leakgan = LeakGAN(SEQ_LENGTH,num_classes=2,vocab_size=vocab_size,emb_dim=EMB_DIM,dis_emb_dim=dis_embedding_dim,filter_sizes=dis_filter_sizes,num_filters=dis_num_filters,
                        batch_size=BATCH_SIZE,hidden_dim=HIDDEN_DIM,start_token=START_TOKEN,goal_out_size=GOAL_OUT_SIZE,goal_size=GOAL_SIZE,step_size=4,D_model=discriminator,
                      learning_rate=LEARNING_RATE)
    if SEQ_LENGTH == 40:
        target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)  # The oracle model
    else:
        target_lstm = TARGET_LSTM20(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN,target_params)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file, 0)
    for a in range(1):
        g = sess.run(leakgan.gen_x,feed_dict={leakgan.drop_out:0.8,leakgan.train:1})
        print(g)

        print("epoch:",a,"  ")

    log = open('save/experiment-log.txt', 'w')
    gen_data_loader.create_batches(positive_file)
    saver_variables = tf.global_variables()
    saver = tf.train.Saver(saver_variables)
    model = tf.train.latest_checkpoint(model_path)
    print(model)
    if FLAGS.restore and model:
        # model = tf.train.latest_checkpoint(model_path)
        # if model and FLAGS.restore:
        if FLAGS.model:  # a specific checkpoint name was supplied
            print(model_path + '/' + FLAGS.model)
            saver.restore(sess, model_path + '/' + FLAGS.model)
        else:
            saver.restore(sess, model)
    else:
        if FLAGS.resD and FLAGS.model:
                print(model_path + '/' + FLAGS.model)
                saver.restore(sess, model_path + '/' + FLAGS.model)

                print('Start pre-training...')
                log.write('pre-training...\n')
                for epoch in range(PRE_EPOCH_NUM):
                    loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                    if epoch % 5 == 0:
                        generate_samples(sess, leakgan, BATCH_SIZE, generated_num, eval_file, 0)
                        likelihood_data_loader.create_batches(eval_file)
                        test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                        print('pre-train epoch ', epoch, 'test_loss ', test_loss)
                        buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
                        log.write(buffer)
                        generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file, 0)
                        likelihood_data_loader.create_batches(eval_file)
                        test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                        print("Groud-Truth:", test_loss)
                saver.save(sess, model_path + '/leakgan_pre')
        else:
                print('Start pre-training discriminator...')
                # Train 3 epoch on the generated data and do this for 50 times
                for i in range(10):
                    for _ in range(5):
                        generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file,0)
                        generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file,0)
                        # gen_data_loader.create_batches(positive_file)
                        dis_data_loader.load_train_data(positive_file, negative_file)
                        for _ in range(3):
                            dis_data_loader.reset_pointer()
                            for it in range(dis_data_loader.num_batch):
                                x_batch, y_batch = dis_data_loader.next_batch()
                                feed = {
                                    discriminator.D_input_x: x_batch,
                                    discriminator.D_input_y: y_batch,
                                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                                }
                                D_loss,_ = sess.run([discriminator.D_loss,discriminator.D_train_op], feed)
                                # # print 'D_loss ', D_loss
                                # buffer =  str(D_loss) + '\n'
                                # log.write(buffer)
                        leakgan.update_feature_function(discriminator)
                    saver.save(sess, model_path + '/leakgan_preD')

            # saver.save(sess, model_path + '/leakgan')
        #  pre-train generator
                    print('Start pre-training...')
                    log.write('pre-training...\n')
                    for epoch in range(PRE_EPOCH_NUM // 10):
                        loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                        if epoch % 5 == 0:
                            generate_samples(sess, leakgan, BATCH_SIZE, generated_num, eval_file,0)
                            likelihood_data_loader.create_batches(eval_file)
                            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
                            buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
                            log.write(buffer)
                            generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file, 0)
                            likelihood_data_loader.create_batches(eval_file)
                            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                            print("Groud-Truth:", test_loss)
                saver.save(sess, model_path + '/leakgan_pre')

    gencircle = 1
    #
    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):

            for gi in range(gencircle):
                samples = leakgan.generate(sess,1.0,1)
                rewards = get_reward(leakgan, discriminator,sess, samples, 4, dis_dropout_keep_prob)
                feed = {leakgan.x: samples, leakgan.reward: rewards,leakgan.drop_out:1.0}
                _,_,g_loss,w_loss = sess.run([leakgan.manager_updates,leakgan.worker_updates,leakgan.goal_loss,leakgan.worker_loss], feed_dict=feed)
                print('total_batch: ', total_batch, "  ",g_loss,"  ", w_loss)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, leakgan, BATCH_SIZE, generated_num, eval_file,0)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)
            generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file, 0)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print("Groud-Truth:" ,test_loss)

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file,0)
            generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file,0)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.D_input_x: x_batch,
                        discriminator.D_input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    D_loss, _ = sess.run([discriminator.D_loss, discriminator.D_train_op], feed)
                    # print 'D_loss ', D_loss
            leakgan.update_feature_function(discriminator)
    log.close()
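# The FLAGS attributes read above (length, restore, resD, model) are defined elsewhere in the
# project's Main.py; a minimal tf.app.flags sketch consistent with how they are used here,
# with default values that are assumptions:

flags = tf.app.flags
flags.DEFINE_integer('length', 20, 'sequence length passed to the data loaders')
flags.DEFINE_boolean('restore', False, 'restore a previously saved full model')
flags.DEFINE_boolean('resD', False, 'restore only a pre-trained discriminator')
flags.DEFINE_string('model', '', 'checkpoint file name under model_path')
FLAGS = flags.FLAGS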
def main():
    # load rhyme table
    table = np.load("./data/table.npy")
    np.random.seed(SEED)
    random.seed(SEED)

    # data loader
    # gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    input_data_loader = Input_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    D = Discriminator(SEQ_LENGTH, num_class, vocab_size, dis_emb_size,
                      dis_filter_sizes, dis_num_filters, 0.2)
    G = Generator(vocab_size,
                  BATCH_SIZE,
                  EMB_DIM,
                  HIDDEN_DIM,
                  SEQ_LENGTH,
                  START_TOKEN,
                  table,
                  has_input=True)

    # avoid occupy all the memory of the GPU
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # savers for different models
    saver_gen = tf.train.Saver()
    saver_dis = tf.train.Saver()
    saver_seqgan = tf.train.Saver()

    # gen_data_loader.create_batches(positive_file)
    input_data_loader.create_batches(x_file, y_file)
    log = open('./experiment-log.txt', 'w')
    #  pre-train generator
    if pre_train_gen_path:
        print("loading pretrain generator model...")
        log.write("loading pretrain generator model...")
        restore_model(G, sess, saver_gen, pre_train_gen_path)
        print("loaded")
    else:
        log.write('pre-training generator...\n')
        print('Start pre-training...')
        best = float('inf')  # track the lowest pre-training loss across epochs
        for epoch in range(PRE_GEN_NUM):
            s = time.time()
            # loss = pre_train_epoch(sess, G, gen_data_loader)
            loss = pre_train_epoch(sess, G, input_data_loader)
            print("Epoch ", epoch, " loss: ", loss)
            log.write("Epoch:\t" + str(epoch) + "\tloss:\t" + str(loss) + "\n")
            print("pre-train generator epoch time: ", time.time() - s, " s")
            if loss < best:
                saver_gen.save(sess, "./model/pre_gen/pretrain_gen_best")
                best = loss
    dev_loader = Input_Data_loader(BATCH_SIZE)
    dev_loader.create_batches(dev_x, dev_y)

    if pre_train_dis_path:
        print("loading pretrain discriminator model...")
        log.write("loading pretrain discriminator model...")
        restore_model(D, sess, saver_dis, pre_train_dis_path)
        print("loaded")
    else:
        log.write('pre-training discriminator...\n')
        print("Start pre-train the discriminator")
        s = time.time()
        best = 0.0  # track the highest discriminator accuracy across epochs
        for epoch in range(PRE_DIS_NUM):
            # generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file)
            generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file,
                             input_data_loader)
            # dis_data_loader.load_train_data(positive_file, negative_file)
            dis_data_loader.load_train_data(y_file, negative_file)
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        D.input_x: x_batch,
                        D.input_y: y_batch,
                        D.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, acc = sess.run([D.train_op, D.accuracy], feed)
            print("Epoch ", epoch, " Accuracy: ", acc)
            log.write("Epoch:\t" + str(epoch) + "\tAccuracy:\t" + str(acc) +
                      "\n")
            # if epoch % 20  == 0 or epoch == PRE_DIS_NUM -1:
            #     print("saving at epoch: ", epoch)
            #     saver_dis.save(sess, "./model/per_dis/pretrain_dis", global_step=epoch)
            if acc > best:
                saver_dis.save(sess, "./model/pre_dis/pretrain_dis_best")
                best = acc
        print("pre-train discriminator: ", time.time() - s, " s")

    g_beta = G_beta(G, update_rate=0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('Start adversarial training...\n')

    for total_batch in range(TOTAL_BATCH):
        s = time.time()
        for it in range(ADV_GEN_TIME):
            for i in range(input_data_loader.num_batch):
                input_x, target = input_data_loader.next_batch()
                samples = G.generate(sess, input_x)
                rewards = g_beta.get_reward(sess, target, input_x, sample_time,
                                            D)
                avg = np.mean(np.sum(rewards, axis=1), axis=0) / SEQ_LENGTH
                print(" epoch : %d time : %di: %d avg %f" %
                      (total_batch, it, i, avg))
                feed = {G.x: samples, G.rewards: rewards, G.inputs: input_x}
                _ = sess.run(G.g_update, feed_dict=feed)
        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            avg = np.mean(np.sum(rewards, axis=1), axis=0) / SEQ_LENGTH
            buffer = 'epoch:\t' + str(total_batch) + '\treward:\t' + str(
                avg) + '\n'
            print('total_batch: ', total_batch, 'average reward: ', avg)
            log.write(buffer)

            saver_seqgan.save(sess,
                              "./model/seq_gan/seq_gan",
                              global_step=total_batch)

        g_beta.update_params()

        # train the discriminator
        for it in range(ADV_GEN_TIME // GEN_VS_DIS_TIME):
            # generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file)
            generate_samples(sess, G, BATCH_SIZE, generated_num, negative_file,
                             input_data_loader)
            dis_data_loader.load_train_data(y_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for batch in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        D.input_x: x_batch,
                        D.input_y: y_batch,
                        D.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(D.train_op, feed_dict=feed)
        print("Adversarial Epoch consumed: ", time.time() - s, " s")

    # final generation
    print("Finished")
    log.close()
    # save model

    print("Training Finished, starting to generating test ")
    test_loader = Input_Data_loader(batch_size=BATCH_SIZE)
    test_loader.create_batches(test_x, test_y)

    generate_samples(sess, G, BATCH_SIZE, test_num, test_file + "_final.txt",
                     test_loader)
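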
def main(FLAGS):

    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    if use_real_world_data:
        vocab_size = 27
        # split to train-valid-test
        real_data_train_file = real_data_file_path + '-train'
        real_data_valid_file = real_data_file_path + '-valid'
        real_data_test_file = real_data_file_path + '-test'
        real_data_dict_file = real_data_file_path + '-dict.json'
        if not os.path.exists(real_data_train_file):
            split_text8(real_data_file_path)
        charmap, inv_charmap = create_real_data_dict(real_data_train_file,
                                                     real_data_dict_file)
        # gen_data_loader = Gen_Data_loader_text8(BATCH_SIZE,charmap,inv_charmap,SEQ_LENGTH)
        # dis_data_loader = Dis_dataloader_text8(BATCH_SIZE,charmap,inv_charmap,SEQ_LENGTH)
    else:
        gen_data_loader = Gen_Data_loader(BATCH_SIZE)
        likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
        vocab_size = 5000
        dis_data_loader = Dis_dataloader(BATCH_SIZE)

    experiments_list = [
        exp for exp in os.listdir('ckp')
        if os.path.isdir(os.path.join('ckp', exp))
    ]
    if FLAGS.epoch_exp:
        experiments_list.sort(key=lambda x: int(x.split('_epoch_')[-1]))
        stats = np.zeros([2, len(experiments_list)], dtype=np.float32)

    for i, exp_name in enumerate(experiments_list):
        print(
            '#########################################################################'
        )
        print('loading model [%0s]...' % exp_name)

        # restore generator arch
        try:
            if FLAGS.epoch_exp:
                config = os.path.join(
                    'ckp', 'config_' + exp_name.split('_epoch_')[0] + '.txt')
            else:
                config = os.path.join('ckp', 'config_' + exp_name + '.txt')
            EMB_DIM = restore_param_from_config(config, param='gen_emb_dim')
            HIDDEN_DIM = restore_param_from_config(config,
                                                   param='gen_hidden_dim')
            dis_embedding_dim = restore_param_from_config(config,
                                                          param='dis_emb_dim')
            seq_len = restore_param_from_config(config, param='seq_len')
        except Exception:
            print("ERROR: CONFIG FILE WAS NOT FOUND - skipping model")
            continue
        assert type(EMB_DIM) == int
        assert type(HIDDEN_DIM) == int
        assert type(dis_embedding_dim) == int
        assert type(seq_len) == int

        if seq_len == 20:
            dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
            dis_num_filters = [
                100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160
            ]
            LEARNING_RATE = 0.0015
        elif seq_len == 40:
            dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 30, 40]
            dis_num_filters = [
                100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160,
                160, 160  # extended to 14 entries to match dis_filter_sizes above
            ]
            LEARNING_RATE = 0.0005
        else:
            exit(0)
        print(SEQ_LENGTH)

        GOAL_OUT_SIZE = sum(dis_num_filters)
        GOAL_SIZE = 16

        tf.reset_default_graph()
        # generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
        discriminator = Discriminator(SEQ_LENGTH,
                                      num_classes=2,
                                      vocab_size=vocab_size,
                                      dis_emb_dim=dis_embedding_dim,
                                      filter_sizes=dis_filter_sizes,
                                      num_filters=dis_num_filters,
                                      batch_size=BATCH_SIZE,
                                      hidden_dim=HIDDEN_DIM,
                                      start_token=START_TOKEN,
                                      goal_out_size=GOAL_OUT_SIZE,
                                      step_size=4)
        generator = LeakGAN(SEQ_LENGTH,
                            num_classes=2,
                            vocab_size=vocab_size,
                            emb_dim=EMB_DIM,
                            dis_emb_dim=dis_embedding_dim,
                            filter_sizes=dis_filter_sizes,
                            num_filters=dis_num_filters,
                            batch_size=BATCH_SIZE,
                            hidden_dim=HIDDEN_DIM,
                            start_token=START_TOKEN,
                            goal_out_size=GOAL_OUT_SIZE,
                            goal_size=GOAL_SIZE,
                            step_size=4,
                            D_model=discriminator,
                            learning_rate=LEARNING_RATE)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        # sess.run(tf.global_variables_initializer())

        # restore weights
        save_file = os.path.join('.', 'ckp', exp_name, exp_name)
        reader = tf.train.NewCheckpointReader(save_file)
        saved_shapes = reader.get_variable_to_shape_map()
        var_names = sorted([(var.name, var.name.split(':')[0])
                            for var in tf.global_variables()
                            if var.name.split(':')[0] in saved_shapes])
        restore_vars = []
        name2var = dict(
            list(
                zip([x.name.split(':')[0] for x in tf.global_variables()],
                    tf.global_variables())))
        with tf.variable_scope('', reuse=True):
            for var_name, saved_var_name in var_names:
                curr_var = name2var[saved_var_name]
                var_shape = curr_var.get_shape().as_list()
                if var_shape == saved_shapes[saved_var_name]:
                    restore_vars.append(curr_var)
                else:
                    print(("Not loading: %s." % saved_var_name))
        saver = tf.train.Saver(restore_vars)
        print("Restoring vars:")
        print(restore_vars)
        saver.restore(sess, save_file)

        # if exp_name == 'regular_120_50_200':
        #     print('#########################################################################')
        #     print('Conducting convergence expariment...')
        #     test_data_loader = Gen_Data_loader_text8(BATCH_SIZE,charmap,inv_charmap,SEQ_LENGTH)
        #     test_data_loader.create_batches(real_data_test_file)
        #     results = convergence_experiment(sess, generator, test_data_loader)
        #     print('Saving results...')
        #     np.save('SeqGan_' + exp_name + '_conv_results',results)
        if FLAGS.dump_samples:
            print('###')
            print('Saving samples file...')
            generate_real_data_samples(sess, generator, BATCH_SIZE, BATCH_SIZE,
                                       "save/lm_eval_file_%0s.txt" % exp_name,
                                       inv_charmap)

        print('###')
        print('Start Language Model Evaluation...')
        test_data_loader = Gen_Data_loader_text8(BATCH_SIZE, charmap,
                                                 inv_charmap, SEQ_LENGTH)
        if FLAGS.test:
            test_data_loader.create_batches(real_data_test_file)
            print("USING TEXT8 TEST SET")
        else:
            test_data_loader.create_batches(real_data_valid_file)
            print("USING TEXT8 VALID SET")
        BPC_direct, BPC_approx = language_model_evaluation(sess,
                                                           generator,
                                                           test_data_loader,
                                                           is_test=FLAGS.test)

        str = "[%0s] BPC_direct = %f" % (exp_name, BPC_direct)
        with open(RESULTS_PATH, 'a') as f:
            f.write(str + '\n')
        print(str)

        if FLAGS.test:
            str = "[%0s] BPC_approx = %f" % (exp_name, BPC_approx)
            with open(RESULTS_PATH, 'a') as f:
                f.write(str + '\n')
            print(str)

        if FLAGS.epoch_exp:
            stats[0, i] = int(exp_name.split('_epoch_')[-1])
            stats[1, i] = BPC_direct

    if FLAGS.epoch_exp:
        np.save('direct_results_epoch_exp', stats)
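# This evaluation main() expects a FLAGS object with boolean epoch_exp, dump_samples and
# test attributes; a minimal, hypothetical argparse stand-in that would drive it:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Language-model evaluation over saved checkpoints')
    parser.add_argument('--epoch_exp', action='store_true', help='treat ckp/ entries as per-epoch checkpoints')
    parser.add_argument('--dump_samples', action='store_true', help='also write a samples file per experiment')
    parser.add_argument('--test', action='store_true', help='evaluate on the text8 test split instead of valid')
    main(parser.parse_args())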
Example #13
0
def init_data_loader(positive_file):
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    gen_data_loader.create_batches(positive_file)
    return gen_data_loader, dis_data_loader
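# A minimal usage sketch for the helper above; BATCH_SIZE and SEQ_LENGTH are assumed
# module-level constants, and 'save/real_data.txt' is a hypothetical positive-sample file
# of space-separated token ids:

gen_loader, dis_loader = init_data_loader('save/real_data.txt')
for _ in range(gen_loader.num_batch):
    batch = gen_loader.next_batch()  # shape: (BATCH_SIZE, SEQ_LENGTH) token ids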
Example #14
0
    def train(self):

        data_save_path = self.data_path

        sentiment_i = np.where(self.data['captions'][:, 3] != 0)[0]
        captions = self.data['captions'][sentiment_i, :21]
        n_examples = captions.shape[0]
        n_iters_per_epoch = int(np.floor(float(n_examples) / self.batch_size))
        image_idxs = self.data['image_idxs'][sentiment_i]

        features = self.data['features'].reshape(-1, 49, 2048)

        val_features = self.val_data['features'].reshape(-1, 49, 2048)

        n_iters_val = int(
            np.ceil(float(val_features.shape[0]) / self.batch_size))

        with tf.variable_scope(tf.get_variable_scope()):
            loss = self.model.build_model()
            tf.get_variable_scope().reuse_variables()
            _, _, generated_captions = self.model.build_sampler(
                max_len=self.model.T - 4)

        with tf.variable_scope(tf.get_variable_scope()):
            optimizer = self.optimizer(learning_rate=self.learning_rate)
            params = [
                param for param in tf.trainable_variables()
                if not ('discriminator' in param.name)
            ]
            grads = tf.gradients(loss, params)
            grads_and_vars = list(zip(grads,
                                      params))  #tf.trainable_variables()))
            train_op = optimizer.apply_gradients(grads_and_vars=grads_and_vars)

        tf.summary.scalar('batch_loss', loss)
        for var in tf.trainable_variables():
            tf.summary.histogram(var.op.name, var)
        for grad, var in grads_and_vars:
            print var.op.name
            tf.summary.histogram(var.op.name + '/gradient', grad)

        summary_op = tf.summary.merge_all()

        print "The number of epoch: %d" % self.n_epochs
        print "Data size: %d" % n_examples
        print "Batch size: %d" % self.batch_size
        print "Iterations per epoch: %d" % n_iters_per_epoch

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True

        dis_embedding_dim = 256
        dis_filter_sizes = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, self.model.T - 4
        ]
        dis_num_filters = [
            100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160
        ]
        dis_l2_reg_lambda = 0.2

        discriminator = Discriminator(sequence_length=self.model.T - 4,
                                      num_classes=2,
                                      vocab_size=self.model.V,
                                      embedding_size=dis_embedding_dim,
                                      filter_sizes=dis_filter_sizes,
                                      num_filters=dis_num_filters,
                                      l2_reg_lambda=dis_l2_reg_lambda)

        rollout = ROLLOUT(self.model, 0.8)

        dis_data_loader = Dis_dataloader(self.dis_batch_size)

        rewards = np.zeros((self.batch_size, self.model.T - 4),
                           dtype=np.float32)

        dis_results_file = open(
            os.path.join(self.model_path, 'dis_results_file_4.txt'), 'w')

        with tf.Session(config=config) as sess:

            tf.global_variables_initializer().run()
            summary_writer = tf.summary.FileWriter(
                self.log_path, graph=tf.get_default_graph())
            saver = tf.train.Saver(max_to_keep=40)

            if self.pretrained_model is not None:
                print "Start training with pretrained Model.."
                saver.restore(sess, self.pretrained_model)

            prev_loss = -1
            curr_loss = 0
            start_t = time.time()

            print 'Start pre-training...'

            for e in range(0):  #self.n_epochs):

                rand_idxs = np.random.permutation(n_examples)
                captions = captions[rand_idxs]
                image_idxs = image_idxs[rand_idxs]

                for i in range(n_iters_per_epoch):

                    captions_batch = captions[i * self.batch_size:(i + 1) *
                                              self.batch_size]
                    image_idxs_batch = image_idxs[i * self.batch_size:(i + 1) *
                                                  self.batch_size]
                    features_batch = features[image_idxs_batch]

                    feed_dict = {
                        self.model.whole_samples:
                        captions_batch[:, 4:self.model.T],
                        self.model.rewards: rewards,
                        self.model.features: features_batch,
                        self.model.captions: captions_batch,
                        self.model.mode_learning: 1
                    }

                    _, l = sess.run([train_op, loss], feed_dict)

                    curr_loss += l

                    if (i + 1) % self.print_every == 0:

                        ground_truths = captions[image_idxs ==
                                                 image_idxs_batch[0], 4:]
                        decoded = decode_captions(ground_truths,
                                                  self.model.idx_to_word)
                        for j, gt in enumerate(decoded):
                            print "Ground truth %d: %s" % (j + 1, gt)
                        feed_dict = {
                            self.model.features:
                            features_batch,
                            self.model.whole_samples:
                            captions_batch[:, 4:self.model.T],
                            self.model.nsample:
                            0,
                            self.model.mode_sampling:
                            1,
                            self.model.captions:
                            captions_batch
                        }

                        gen_caps = sess.run(generated_captions, feed_dict)
                        decoded = decode_captions(gen_caps,
                                                  self.model.idx_to_word)
                        print "Generated caption: %s\n" % decoded[0]

                print "Previous epoch loss: ", prev_loss
                print "Current epoch loss: ", curr_loss
                print "Elapsed time: ", time.time() - start_t
                prev_loss = curr_loss
                curr_loss = 0

                captions_batch = captions[0 * self.batch_size:(0 + 1) *
                                          self.batch_size]
                if self.print_bleu:

                    all_gen_cap = np.ndarray(
                        (val_features.shape[0], self.model.T - 4))
                    pos = [1]
                    neg = [-1]

                    val_features[:, :, 2048:2052] = [0, 1, 0, 1]

                    for i in range(n_iters_val):
                        features_batch = val_features[i *
                                                      self.batch_size:(i + 1) *
                                                      self.batch_size]
                        feed_dict = {
                            self.model.features:
                            features_batch,
                            self.model.whole_samples:
                            captions_batch[:, 4:self.model.T],
                            self.model.nsample:
                            0,
                            self.model.mode_sampling:
                            1,
                            self.model.captions:
                            captions_batch
                        }
                        gen_cap = sess.run(generated_captions,
                                           feed_dict=feed_dict)
                        all_gen_cap[i * self.batch_size:(i + 1) *
                                    self.batch_size] = gen_cap

                    all_decoded = decode_captions(all_gen_cap,
                                                  self.model.idx_to_word)
                    save_pickle(
                        all_decoded,
                        os.path.join(data_save_path,
                                     "val/val.candidate.captions.pkl"))
                    scores = evaluate(data_path=data_save_path,
                                      split='val',
                                      get_scores=True)

                    print "scores_pos==================", scores
                    write_bleu(scores=scores,
                               path=self.model_path,
                               epoch=e,
                               senti=pos)

                    val_features[:, :, 2048:2052] = [0, 0, 1, 2]

                    for i in range(n_iters_val):
                        features_batch = val_features[i *
                                                      self.batch_size:(i + 1) *
                                                      self.batch_size]
                        feed_dict = {
                            self.model.features:
                            features_batch,
                            self.model.whole_samples:
                            captions_batch[:, 4:self.model.T],
                            self.model.nsample:
                            0,
                            self.model.mode_sampling:
                            1,
                            self.model.captions:
                            captions_batch
                        }
                        gen_cap = sess.run(generated_captions,
                                           feed_dict=feed_dict)
                        all_gen_cap[i * self.batch_size:(i + 1) *
                                    self.batch_size] = gen_cap

                    all_decoded = decode_captions(all_gen_cap,
                                                  self.model.idx_to_word)
                    save_pickle(
                        all_decoded,
                        os.path.join(data_save_path,
                                     "val/val.candidate.captions.pkl"))
                    scores = evaluate(data_path=data_save_path,
                                      split='val',
                                      get_scores=True)
                    print "scores_neg==================", scores
                    write_bleu(scores=scores,
                               path=self.model_path,
                               epoch=e,
                               senti=neg)

                if (e + 1) % self.save_every == 0:
                    saver.save(sess,
                               os.path.join(self.model_path, 'model'),
                               global_step=e + 1)
                    print "model-%s saved." % (e + 1)

            print 'Start pre-training discriminator...'
            for e in range(0):  #self.n_epochs):

                rand_idxs = np.random.permutation(n_examples)
                captions = captions[rand_idxs]
                image_idxs = image_idxs[rand_idxs]
                dis_loss = 0
                for i in range(n_iters_per_epoch):

                    captions_batch = captions[i * self.batch_size:(i + 1) *
                                              self.batch_size]
                    image_idxs_batch = image_idxs[i * self.batch_size:(i + 1) *
                                                  self.batch_size]
                    features_batch = features[image_idxs_batch]

                    feed_dict = {
                        self.model.features: features_batch,
                        self.model.whole_samples:
                        captions_batch[:, 4:self.model.T],
                        self.model.nsample: 0,
                        self.model.mode_sampling: 1,
                        self.model.captions: captions_batch
                    }

                    for d_step in range(3):
                        negative_file = sess.run(generated_captions,
                                                 feed_dict=feed_dict)
                        positive_file = captions_batch[:, 4:self.model.T]
                        dis_data_loader.load_train_data(
                            positive_file, negative_file)
                        for it in xrange(dis_data_loader.num_batch):
                            x_batch, y_batch = dis_data_loader.next_batch()
                            feed = {
                                discriminator.input_x:
                                x_batch,
                                discriminator.input_y:
                                y_batch,
                                discriminator.dropout_keep_prob:
                                self.dis_dropout_keep_prob
                            }
                            dis_l = sess.run(discriminator.loss, feed)
                            dis_loss = dis_loss + dis_l
                            _ = sess.run(discriminator.train_op, feed)
                            _ = sess.run(discriminator.params_clip, feed)

                dis_results_file.write('The loss in epoch %i is %f \n' %
                                       (e + 1, dis_loss))
                dis_results_file.flush()

                saver.save(sess,
                           os.path.join(self.model_path, 'model_and_dis'),
                           global_step=e + 1)

            print '#########################################################################'
            print 'Start Adversarial Training...'
            for e in range(self.n_epochs):

                rand_idxs = np.random.permutation(n_examples)
                captions = captions[rand_idxs]
                image_idxs = image_idxs[rand_idxs]

                for i in range(n_iters_per_epoch):

                    captions_batch = captions[i * self.batch_size:(i + 1) *
                                              self.batch_size]
                    image_idxs_batch = image_idxs[i * self.batch_size:(i + 1) *
                                                  self.batch_size]
                    features_batch = features[image_idxs_batch]

                    feed_dict = {
                        self.model.features: features_batch,
                        self.model.whole_samples:
                        captions_batch[:, 4:self.model.T],
                        self.model.nsample: 0,
                        self.model.mode_sampling: 1,
                        self.model.captions: captions_batch
                    }
                    samples_whole = sess.run(generated_captions,
                                             feed_dict=feed_dict)

                    rewards = rollout.get_reward(sess, samples_whole,
                                                 generated_captions,
                                                 self.rollout_num,
                                                 discriminator, features_batch,
                                                 captions_batch)

                    feed_dict = {
                        self.model.whole_samples: samples_whole,
                        self.model.rewards: rewards,
                        self.model.features: features_batch,
                        self.model.captions: captions_batch,
                        self.model.mode_learning: 2
                    }
                    _, l_reward = sess.run([train_op, loss],
                                           feed_dict=feed_dict)
                    curr_loss += l_reward

                    feed_dict = {
                        self.model.features: features_batch,
                        self.model.whole_samples:
                        captions_batch[:, 4:self.model.T],
                        self.model.nsample: 0,
                        self.model.mode_sampling: 1,
                        self.model.captions: captions_batch
                    }

                    for d_step in range(3):
                        negative_file = sess.run(generated_captions,
                                                 feed_dict=feed_dict)
                        positive_file = captions_batch[:, 4:self.model.T]
                        dis_data_loader.load_train_data(
                            positive_file, negative_file)
                        for it in xrange(dis_data_loader.num_batch):
                            x_batch, y_batch = dis_data_loader.next_batch()
                            feed = {
                                discriminator.input_x:
                                x_batch,
                                discriminator.input_y:
                                y_batch,
                                discriminator.dropout_keep_prob:
                                self.dis_dropout_keep_prob
                            }
                            _ = sess.run(discriminator.train_op, feed)
                            _ = sess.run(discriminator.params_clip, feed)

                    if (i + 1) % self.print_every == 0:

                        ground_truths = captions[image_idxs ==
                                                 image_idxs_batch[0], 4:]
                        decoded = decode_captions(ground_truths,
                                                  self.model.idx_to_word)
                        for j, gt in enumerate(decoded):
                            print "Ground truth %d: %s" % (j + 1, gt)
                        feed_dict = {
                            self.model.features:
                            features_batch,
                            self.model.whole_samples:
                            captions_batch[:, 4:self.model.T],
                            self.model.nsample:
                            0,
                            self.model.mode_sampling:
                            1,
                            self.model.captions:
                            captions_batch
                        }
                        gen_caps = sess.run(generated_captions, feed_dict)
                        decoded = decode_captions(gen_caps,
                                                  self.model.idx_to_word)
                        print "Generated caption: %s\n" % decoded[0]

                print "Previous epoch loss: ", prev_loss
                print "Current epoch loss: ", curr_loss
                print "Elapsed time: ", time.time() - start_t
                prev_loss = curr_loss
                curr_loss = 0

                captions_batch = captions[0 * self.batch_size:(0 + 1) *
                                          self.batch_size]
                if self.print_bleu:
                    all_gen_cap = np.ndarray(
                        (val_features.shape[0], self.model.T - 4))

                    pos = [1]
                    neg = [-1]

                    val_features[:, :, 2048:2052] = [0, 1, 0, 1]

                    for i in range(n_iters_val):
                        features_batch = val_features[i *
                                                      self.batch_size:(i + 1) *
                                                      self.batch_size]
                        feed_dict = {
                            self.model.features:
                            features_batch,
                            self.model.whole_samples:
                            captions_batch[:, 4:self.model.T],
                            self.model.nsample:
                            0,
                            self.model.mode_sampling:
                            1,
                            self.model.captions:
                            captions_batch
                        }
                        gen_cap = sess.run(generated_captions,
                                           feed_dict=feed_dict)
                        all_gen_cap[i * self.batch_size:(i + 1) *
                                    self.batch_size] = gen_cap

                    all_decoded = decode_captions(all_gen_cap,
                                                  self.model.idx_to_word)
                    save_pickle(
                        all_decoded,
                        os.path.join(data_save_path,
                                     "val/val.candidate.captions.pkl"))
                    scores = evaluate(data_path=data_save_path,
                                      split='val',
                                      get_scores=True)

                    print "scores_pos==================", scores

                    write_bleu(scores=scores,
                               path=self.model_path,
                               epoch=e,
                               senti=pos)

                    val_features[:, :, 2048:2052] = [0, 0, 1, 2]

                    for i in range(n_iters_val):

                        features_batch = val_features[i *
                                                      self.batch_size:(i + 1) *
                                                      self.batch_size]
                        feed_dict = {
                            self.model.features:
                            features_batch,
                            self.model.whole_samples:
                            captions_batch[:, 4:self.model.T],
                            self.model.nsample:
                            0,
                            self.model.mode_sampling:
                            1,
                            self.model.captions:
                            captions_batch
                        }

                        gen_cap = sess.run(generated_captions,
                                           feed_dict=feed_dict)
                        all_gen_cap[i * self.batch_size:(i + 1) *
                                    self.batch_size] = gen_cap

                    all_decoded = decode_captions(all_gen_cap,
                                                  self.model.idx_to_word)
                    save_pickle(
                        all_decoded,
                        os.path.join(data_save_path,
                                     "val/val.candidate.captions.pkl"))
                    scores = evaluate(data_path=data_save_path,
                                      split='val',
                                      get_scores=True)

                    print "scores_neg==================", scores

                    write_bleu(scores=scores,
                               path=self.model_path,
                               epoch=e,
                               senti=neg)

                if (e + 1) % self.save_every == 0:
                    saver.save(sess,
                               os.path.join(self.model_path, 'model_adv'),
                               global_step=e + 1)
                    print "model-%s saved." % (e + 1)
Example #15
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
    vocab_size = 5000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    # with open('./save/target_params.pkl','rb') as f:
    #     print('readable:',f.readable())
    #     target_params=pickle.load(f)
    target_params = pickle.load(open('save/target_params.pkl','rb'),encoding='ISO-8859-1')
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    discriminator = Discriminator(sequence_length=20, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim, 
                                filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print ('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print ('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)

    print ('Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for _ in range(50):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print ('#########################################################################')
    print ('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()
    print(test_loss)
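Note: in the adversarial loop above, `generator.g_updates` is a policy-gradient (REINFORCE) step that treats the sampled tokens as actions and the rollout scores as per-step returns. The snippet below is only a minimal, self-contained sketch of that kind of update in TF1 style; the placeholder names, sizes and the stand-in `logits` variable are illustrative assumptions, not the repo's actual Generator internals.

import tensorflow as tf

BATCH, SEQ_LEN, VOCAB = 4, 20, 5000
x = tf.placeholder(tf.int32, [BATCH, SEQ_LEN])            # sampled token ids (actions)
rewards = tf.placeholder(tf.float32, [BATCH, SEQ_LEN])    # rollout rewards per time step
logits = tf.get_variable('logits', [BATCH, SEQ_LEN, VOCAB])  # stand-in for the LSTM outputs

log_prob = tf.nn.log_softmax(logits)                                # log pi(.|state)
picked = tf.reduce_sum(tf.one_hot(x, VOCAB) * log_prob, axis=-1)    # log pi(a_t|s_t)
g_loss = -tf.reduce_sum(picked * rewards)                           # REINFORCE objective
g_updates = tf.train.AdamOptimizer(0.01).minimize(g_loss)           # analogue of generator.g_updates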
示例#16
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0
    vocab_w2idx, vocab_idx2w, len_vocab_w2idx = utils.load_vocab(vacab_file)

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, seq_len=SEQ_LENGTH)
    dis_data_loader = Dis_dataloader(BATCH_SIZE, seq_len=SEQ_LENGTH)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)

    discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
                                filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    gen_data_loader.create_batches(positive_file)
    print("gen_data_loader num_batches: ", gen_data_loader.num_batch)


    #  pre-train generator
    print('Start pre-training...')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 1 == 0:
            utils.test_demo(sess, generator, 8, 8, eval_file, vocab_idx2w, epoch)
            print('pre-train epoch ', epoch, 'train_loss ', loss)

    print('Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for _ in range(10):
        utils.generate_samples(sess, generator,  64, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print('#########################################################################')
    print('Start Adversarial Training...')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            print('adversarial-train epoch ', total_batch)
            utils.test_demo(sess, generator, 8, 8, eval_file, vocab_idx2w, total_batch)
        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
示例#17
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    #
    # Declare data loader
    # ----------------------------------------------------------------------------
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    # ----------------------------------------------------------------------------

    #
    # Declare Generator & Discriminator
    # ----------------------------------------------------------------------------
    # declare: generator
    generator = Generator(NUM_EMB, EMB_DIM, BATCH_SIZE, HIDDEN_DIM_1, HIDDEN_DIM_2, SEQ_LENGTH_1, SEQ_LENGTH_2, START_TOKEN)

    # declare: discriminator
    discriminator = Discriminator(sequence_length=SEQ_LENGTH_2, emb_vec_dim=EMB_DIM, num_classes=2,
                                  vocab_size=1, embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes, num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)
    # ----------------------------------------------------------------------------

    #
    # Set the session <sess>
    # ----------------------------------------------------------------------------
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # ----------------------------------------------------------------------------

    #
    # load the air data and write positive file
    gen_data_loader.load_data(root_path + positive_file, site_list, target_site, target_kind, training_year,
                              training_duration, pollution_kind, SEQ_LENGTH_1, SEQ_LENGTH_2)

    likelihood_data_loader.load_data(root_path + positive_file, site_list, target_site, target_kind, training_year,
                                     training_duration, pollution_kind, SEQ_LENGTH_1, SEQ_LENGTH_2)

    gen_data_loader.create_batches(positive_file, SEQ_LENGTH_2)

    log = open('save/experiment-log.txt', 'w')

    #
    # Pre-train <generator> by using <gen_data_loader>,
    # and then compute the <test_loss> of <target_lstm> and <likelihood_data_loader>
    # ----------------------------------------------------------------------------
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            # generate samples by using <generator> and write the samples to file <eval_file>
            generate_samples(sess, generator, BATCH_SIZE, generated_num, gen_data_loader, eval_file)

            # load samples from file <eval_file>
            likelihood_data_loader.create_batches(eval_file, SEQ_LENGTH_2)

            # compute <test_loss> of <target_lstm>, with input <likelihood_data_loader>
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)

            # NOTE: target_lstm is not built in this example, so report the
            # pre-training loss instead of the disabled oracle test_loss
            print('pre-train epoch ', epoch, 'loss ', loss)
            buffer = 'epoch:\t' + str(epoch) + '\tloss:\t' + str(loss) + '\n'
            log.write(buffer)
            # ----------------------------------------------------------------------------

    print('OK')
示例#18
0
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()

    np.random.seed(config_train.seed)

    assert config_train.start_token == 0
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    generator = Generator(config=config_gen)
    generator.build()

    rollout_gen = rollout(config=config_gen)

    target_params = pickle.load(open('save/target_params.pkl', 'rb'),
                                encoding='iso-8859-1')
    target_lstm = TARGET_LSTM(config=config_gen,
                              params=target_params)  # The oracle model

    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    # 'teller' variables belong to the generator/rollout networks (same filter
    # as in the later example from this codebase); without it, var_pretrained
    # would be undefined here.
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    generate_samples(sess, target_lstm, config_train.batch_size,
                     config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator....')

    log.write('pre-training...\n')

    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()

        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _, g_loss = sess.run(
                [gen_pre_update, generator.pretrained_loss],
                feed_dict={
                    generator.input_seqs_pre: batch,
                    generator.input_seqs_mask: np.ones_like(batch)
                })

        if epoch % config_train.test_per_epoch == 0:
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train ', epoch, ' test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_time_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed_dict = {
                    discriminator.input_x:
                    x_batch,
                    discriminator.input_y:
                    y_batch,
                    discriminator.dropout_keep_prob:
                    config_dis.dis_dropout_keep_prob  # dropout
                }
                _ = sess.run(discriminator.train_op, feed_dict)

    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize global variables of optimizer for adversarial training
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshpae)
            feed = {'pred_seq_rollout:0': samples}
            reward_rollout = []
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])

                reward_rollout.append(np.array(reward_tmp))
            # average over all rollouts, then take one policy-gradient step
            rewards = np.sum(reward_rollout,
                             axis=0) / config_train.rollout_num
            _, gen_loss = sess.run(
                [train_adv_update, generator.gen_loss_adv],
                feed_dict={
                    generator.input_seqs_adv: samples,
                    generator.rewards: rewards
                })

        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)

        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)

            for _ in range(config_train.dis_update_time_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x:
                        x_batch,
                        discriminator.input_y:
                        y_batch,
                        discriminator.dropout_keep_prob:
                        config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()
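Note: the reward bookkeeping in the adversarial loop above can be hard to read. The discriminator's P(real) scores come back as a flat, step-major vector, so example r's per-step rewards are gathered with the stride-`batch_size` index `range(r, batch_size * sequence_length, batch_size)` and then averaged over the rollouts. A toy NumPy illustration of just that indexing, with made-up sizes:

import numpy as np

B, T, rollout_num = 3, 4, 2                 # batch size, sequence length, number of rollouts
reward_rollout = []
for _ in range(rollout_num):
    reward_allseq = np.random.rand(B * T)   # flat P(real) scores, step-major ordering
    reward_tmp = [reward_allseq[range(r, B * T, B)] for r in range(B)]  # gather example r's T rewards
    reward_rollout.append(np.array(reward_tmp))                         # shape (B, T)

rewards = np.sum(reward_rollout, axis=0) / rollout_num   # averaged rewards, shape (B, T)
print(rewards.shape)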
示例#19
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    vocab_size = 4839
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    discriminator = Discriminator(SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  dis_emb_dim=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  batch_size=BATCH_SIZE,
                                  hidden_dim=HIDDEN_DIM,
                                  start_token=START_TOKEN,
                                  goal_out_size=GOAL_OUT_SIZE,
                                  step_size=4)
    leakgan = LeakGAN(SEQ_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      emb_dim=EMB_DIM,
                      dis_emb_dim=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      batch_size=BATCH_SIZE,
                      hidden_dim=HIDDEN_DIM,
                      start_token=START_TOKEN,
                      goal_out_size=GOAL_OUT_SIZE,
                      goal_size=GOAL_SIZE,
                      step_size=4,
                      D_model=discriminator)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    for a in range(1):
        g = sess.run(leakgan.gen_x,
                     feed_dict={
                         leakgan.drop_out: 0.8,
                         leakgan.train: 1
                     })
        print(g)
        print('epoch: %d\t' % a)

    log = open('save/experiment-log.txt', 'w')
    generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file,
                     0)
    gen_data_loader.create_batches(positive_file)
    saver_variables = tf.global_variables()
    saver = tf.train.Saver(saver_variables)
    model = tf.train.latest_checkpoint(model_path)
    print(model)
    if FLAGS.restore and model:
        # model = tf.train.latest_checkpoint(model_path)
        # if model and FLAGS.restore:
        if model_path + '/' + FLAGS.model:
            print(model_path + '/' + FLAGS.model)
            saver.restore(sess, model_path + '/' + FLAGS.model)
        else:
            saver.restore(sess, model)
    else:
        if FLAGS.resD and model_path + '/' + FLAGS.model:
            print(model_path + '/' + FLAGS.model)
            saver.restore(sess, model_path + '/' + FLAGS.model)

            print('Start pre-training...')
            log.write('pre-training...\n')
            for epoch in range(PRE_EPOCH_NUM):
                loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                if epoch % 5 == 0:
                    generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                                     negative_file)
                buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                    loss) + '\n'
                log.write(buffer)
            saver.save(sess, model_path + '/leakgan_pre')
        else:
            print('Start pre-training discriminator...')
            # Train 3 epoch on the generated data and do this for 50 times
            for i in range(16):
                for _ in range(5):
                    generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                                     negative_file, 0)
                    # gen_data_loader.create_batches(positive_file)
                    dis_data_loader.load_train_data(positive_file,
                                                    negative_file)
                    for _ in range(3):
                        dis_data_loader.reset_pointer()
                        for it in range(dis_data_loader.num_batch):
                            x_batch, y_batch = dis_data_loader.next_batch()
                            feed = {
                                discriminator.D_input_x:
                                x_batch,
                                discriminator.D_input_y:
                                y_batch,
                                discriminator.dropout_keep_prob:
                                dis_dropout_keep_prob
                            }
                            D_loss, _ = sess.run([
                                discriminator.D_loss, discriminator.D_train_op
                            ], feed)
                            # print 'D_loss ', D_loss
                            buffer = str(D_loss) + '\n'
                            log.write(buffer)
                    leakgan.update_feature_function(discriminator)
                saver.save(sess, model_path + '/leakgan_preD')

                # saver.save(sess, model_path + '/leakgan')
                #  pre-train generator
                print('Start pre-training...')
                log.write('pre-training...\n')
                for epoch in range(PRE_EPOCH_NUM // 16):
                    loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                    if epoch % 5 == 0:
                        generate_samples(sess, leakgan, BATCH_SIZE,
                                         generated_num, negative_file, 0)
                    print('pre-train epoch %d, test_loss %.4f' % (epoch, loss))
                    buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                        loss) + '\n'
                    log.write(buffer)
            saver.save(sess, model_path + '/leakgan_pre')

    gencircle = 1

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            for gi in range(gencircle):
                samples = leakgan.generate(sess, 1.0, 1)
                rewards = get_reward(leakgan, discriminator, sess, samples, 4,
                                     dis_dropout_keep_prob, total_batch,
                                     gen_data_loader)
                feed = {
                    leakgan.x: samples,
                    leakgan.reward: rewards,
                    leakgan.drop_out: 1.0
                }
                _, _, g_loss, w_loss = sess.run([
                    leakgan.manager_updates, leakgan.worker_updates,
                    leakgan.goal_loss, leakgan.worker_loss
                ],
                                                feed_dict=feed)
                print('total_batch: %d\tg_loss: %.4f\tw_loss: %.4f' %
                      (total_batch, g_loss, w_loss))

        # Test
        if total_batch % 10 == 1 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                             "./save/coco_" + str(total_batch) + ".txt", 0)
            saver.save(sess, model_path + '/leakgan', global_step=total_batch)
        if total_batch % 15 == 0:
            for epoch in range(1):
                loss = pre_train_epoch(sess, leakgan, gen_data_loader)
        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                             negative_file, 0)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.D_input_x: x_batch,
                        discriminator.D_input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    D_loss, _ = sess.run(
                        [discriminator.D_loss, discriminator.D_train_op], feed)
                    # print 'D_loss ', D_loss
            leakgan.update_feature_function(discriminator)
    log.close()
示例#20
0
File: train_op.py  Project: xljhtq/SeqGAN
def main(source_file, wordVocab, vocab_size):
    tf.reset_default_graph()
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing

    # todo:  print ("starting generating positive samples...")
    generated_num = gen_data_loader.transform_positive_file_2(
        train_dir + source_file, train_dir + positive_file, wordVocab,
        SEQ_LENGTH)
    print("generated_num: ", generated_num)
    if generated_num < 100: return
    gen_data_loader.create_batches(train_dir + positive_file)

    with tf.variable_scope("Train", reuse=None):
        generator = Generator(wordVocab,
                              vocab_size,
                              BATCH_SIZE,
                              EMB_DIM,
                              HIDDEN_DIM,
                              SEQ_LENGTH,
                              START_TOKEN,
                              learning_rate=g_lrn)

        discriminator = Discriminator(word_vocab=wordVocab,
                                      sequence_length=SEQ_LENGTH,
                                      num_classes=2,
                                      embedding_size=dis_embedding_dim,
                                      filter_sizes=dis_filter_sizes,
                                      num_filters=dis_num_filters,
                                      l2_reg_lambda=dis_l2_reg_lambda,
                                      learning_rate=d_lrn)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # todo:  1.##############pre-train generator##############
    print('Start pre-training generator with MLE...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM_generator):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            buffer = 'epoch:\t' + str(epoch) + '\tloss:\t' + str(loss)
            print(buffer)
            sys.stdout.flush()
            log.write(buffer)

    # todo:  2.##############pre-train discriminator##############
    print('Start pre-training discriminator...')
    for _ in range(PRE_EPOCH_NUM_discriminator):
        ## each pass samples from the generator's distribution, so the fake data differs every time
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):  ## train the discriminator on each batch of fake_data
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    with tf.variable_scope("Train", reuse=None):
        g_beta = ROLLOUT(generator, 0.8)  ## this is the rollout policy g_beta

    # todo:  3.############## Adversarial Training ##############
    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # todo: Train the generator for one batch samples
        for it in range(1):
            samples = generator.generate(sess)
            rewards = g_beta.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss],
                                 feed_dict=feed)

        # Test
        if total_batch % 10 == 0 or total_batch == TOTAL_BATCH - 1:
            buffer = 'epoch:\t' + str(total_batch) + '\tg_loss:\t' + str(
                g_loss)
            print(buffer)
            sys.stdout.flush()
            log.write(buffer)

        g_beta.update_params()

        # todo: Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

        if total_batch % 10 == 0 or total_batch == TOTAL_BATCH - 1:
            out_file = out_negative_file + str(total_batch) + ".txt"
            transform_file(negative_file, wordVocab, out_file)

    generate_samples(sess, generator, BATCH_SIZE, need_generated_samples,
                     negative_file)
    transform_file(negative_file, wordVocab, source_file + ".GEN")
示例#21
0
    return sample_vocab

################################## main() #########################################

# measure elapsed time
start_time = time.time()

tf.reset_default_graph()

random.seed(SEED)
np.random.seed(SEED)

gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
vocab_size = len(vocab_to_int)  # 6447
print(vocab_size)
dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)

generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
                              filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

sess = tf.Session(config=config)
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())

# First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
#  pre-train generator
gen_data_loader.create_batches(positive_file)
示例#22
0
def main():

    starttime = datetime.datetime.now()
    short = {}
    # graph = read_graph_edgelist(network_file)
    graph = read_graph_adjlist(network_file)
    graph_nx = nx.from_dict_of_lists(graph)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)
    generated_num = np.loadtxt(positive_file).shape[0]
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    saver = tf.train.Saver(max_to_keep=12)
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, generate true random walk path
    generate_samples_true(sess, graph, BATCH_SIZE, generated_num,
                          positive_file, SEQ_LENGTH)
    gen_data_loader.create_batches(positive_file)
    #  pre-train generator
    print('Start pre-training...')
    for epoch in range(15):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        endtime = datetime.datetime.now()
        print('pre-train epoch:', epoch, ' test_loss ', loss, ' time:',
              (endtime - starttime).seconds)

    #  pre-train discriminator
    print('Start pre-training discriminator...')
    for _ in range(3):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            d_loss_his = []
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, d_loss = sess.run(
                    [discriminator.train_op, discriminator.loss], feed)
                d_loss_his.append(d_loss)
        endtime = datetime.datetime.now()
        loss, length = dis_ypred_for_auc(sess, discriminator, generator, short,
                                         graph_nx)
        print('discriminator loss: ', np.mean(d_loss_his), 'ypred:', loss,
              'length', length, 'time: ', (endtime - starttime).seconds)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    for total_batch in range(TOTAL_BATCH):
        if total_batch % 5 == 0:
            saver.save(sess, 'log/train.checkpoint', global_step=total_batch)
        # Train the generator for one step
        for it in range(20):
            generator.graph = graph
            start_token = np.random.randint(vocab_size, size=[BATCH_SIZE])
            samples = generator.generate(sess, start_token)
            rewards, sample_temp = rollout.get_reward(sess, samples, 16,
                                                      discriminator)
            feed = {
                generator.x: samples,
                generator.rewards: rewards,
                generator.start_token: start_token
            }
            _, gen_loss = sess.run([generator.g_updates, generator.g_loss],
                                   feed_dict=feed)
            generator.graph = None
            loss, length = dis_ypred_for_auc(sess, discriminator, generator,
                                             short, graph_nx)
            endtime = datetime.datetime.now()

            print('before total_batch: ', total_batch, 'reward: ', loss,
                  'length:', length, 'test_loss: ', gen_loss, 'time: ',
                  (endtime - starttime).seconds)
        rollout.update_params()
        loss, length = dis_ypred_for_auc(sess, discriminator, generator, short,
                                         graph_nx)
        print('after total_batch: ', total_batch, 'reward: ', loss, 'length',
              length, 'time: ', (endtime - starttime).seconds)
        # Train the discriminator
        for _ in range(3):
            generator.graph = None
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            d_loss_his = []
            for _ in range(1):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, d_loss = sess.run(
                        [discriminator.train_op, discriminator.loss], feed)
                    d_loss_his.append(d_loss)
            endtime = datetime.datetime.now()
            loss, length = dis_ypred_for_auc(sess, discriminator, generator,
                                             short, graph_nx)
            print('discriminator loss: ', np.mean(d_loss_his), 'ypred:', loss,
                  'length', length, 'time: ', (endtime - starttime).seconds)

        if total_batch % 5 == 0:
            saver.save(sess, 'log/train.checkpoint', global_step=epoch)
    saver.save(sess, 'log/train.checkpoint', global_step=epoch + 1)
    #generate final fake path
    generate_samples(sess, generator, BATCH_SIZE, generated_num,
                     'save/final.txt')
    command = 'deepwalk --input example_graphs/karate.adjlist --output karate.embeddings --extra save/final.txt'
    subprocess.call(command, shell=True)
示例#23
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    with open('data/ihaiku.pickle', 'rb') as f:
        haiku_list = pickle.load(f)
    #usew2v---------------------------------------------------------------------------------------------
    with open('data/index.pickle', 'rb') as f:
        index = pickle.load(f)
    vocab_size = len(index)
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)

    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        select_haikus(haiku_list, generated_num, positive_file)
        gen_data_loader.create_batches(positive_file)
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            print('pre-train epoch ', epoch)
            buffer = 'epoch:\t' + str(epoch) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for _ in range(50):
        select_haikus(haiku_list, generated_num, positive_file)
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            kigos = select_kigos(kigo_list, BATCH_SIZE)
            samples, rate = generator.generate_with_rate(sess, kigos)
            rewards = rollout.get_reward(sess, samples, 16, discriminator,
                                         rate)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            select_haikus(haiku_list, generated_num, positive_file)
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

        # Test
        print('total_batch:', total_batch)
        if (total_batch - 1) % 50 == 0:
            output_file = 'result/result_{0:04d}_epoch.txt'.format(total_batch)
            generate_samples_with_pred(sess, generator, discriminator,
                                       BATCH_SIZE, generated_num, output_file)
            buffer = 'epoch:\t' + str(total_batch) + '\n'
            log.write(buffer)

    log.close()
示例#24
0
def main(FLAGS):

    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    # if not use_real_world_data:
    #     target_params = pickle.load(open('save/target_params.pkl'))
    #     target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    # discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
    #                             filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    # # tf.reset_default_graph()
    # generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)

    # config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    # sess = tf.Session(config=config)
    # # sess.run(tf.global_variables_initializer())

    experiments_list = [
        exp for exp in os.listdir('ckp')
        if os.path.isdir(os.path.join('ckp', exp))
    ]
    if FLAGS.epoch_exp:
        experiments_list.sort(key=lambda x: int(x.split('_epoch_')[-1]))
        stats = np.zeros([2, len(experiments_list)], dtype=np.float32)

    for i, exp_name in enumerate(experiments_list):
        print(
            '#########################################################################'
        )
        print('loading model [%0s]...' % exp_name)

        if FLAGS.epoch_exp:
            config = os.path.join(
                'ckp', 'config_' + exp_name.split('_epoch_')[0] + '.txt')
        else:
            config = os.path.join('ckp', 'config_' + exp_name + '.txt')

        # restore generator arch
        try:
            EXPERIMENT_NAME = restore_param_from_config(
                config, param='experiment_name')
            EMB_DIM = int(
                restore_param_from_config(config, param='gen_emb_dim'))
            HIDDEN_DIM = int(
                restore_param_from_config(config, param='gen_hidden_dim'))
        except:
            EMB_DIM = 32
            HIDDEN_DIM = 32
            print("WARNING: CONFIG FILE WAS NOT FOUND - USING DEFAULT CONFIG")

        try:
            TOKEN_TYPE = restore_param_from_config(config, param='base_token')
            DATA_PATH = restore_param_from_config(config, param='dataset_path')
        except ValueError:
            TOKEN_TYPE = 'char'
            DATA_PATH = './data/text8/text8'
            print(
                "WARNING: NEW CONFIGURATION WAS NOT FOUND - EVALUATING CHAR-BASED"
            )

        try:
            NUM_LAYERS = int(
                restore_param_from_config(config,
                                          param='gen_num_recurrent_layers'))
        except ValueError:
            NUM_LAYERS = 1
            print("WARNING: NUM LAYERS NOT FOUND - USING 1")

        if use_real_world_data:
            # split to train-valid-test
            real_data_train_file = DATA_PATH + '-train'
            real_data_valid_file = DATA_PATH + '-valid'
            real_data_test_file = DATA_PATH + '-test'
            real_data_dict_file = DATA_PATH + '-dict.json'
            if not os.path.exists(real_data_train_file):
                split_text8(DATA_PATH)
            charmap, inv_charmap = create_real_data_dict(
                real_data_train_file, real_data_dict_file, TOKEN_TYPE)
            vocab_size = len(charmap)
        else:
            gen_data_loader = Gen_Data_loader(BATCH_SIZE)
            likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
            vocab_size = 5000
            dis_data_loader = Dis_dataloader(BATCH_SIZE)

        tf.reset_default_graph()
        generator = Generator(vocab_size,
                              BATCH_SIZE,
                              EMB_DIM,
                              HIDDEN_DIM,
                              SEQ_LENGTH,
                              START_TOKEN,
                              num_recurrent_layers=NUM_LAYERS)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        # sess.run(tf.global_variables_initializer())

        # restore weights
        save_file = os.path.join('.', 'ckp', exp_name, exp_name)
        reader = tf.train.NewCheckpointReader(save_file)
        saved_shapes = reader.get_variable_to_shape_map()
        var_names = sorted([(var.name, var.name.split(':')[0])
                            for var in tf.global_variables()
                            if var.name.split(':')[0] in saved_shapes])
        restore_vars = []
        name2var = dict(
            list(
                zip([x.name.split(':')[0] for x in tf.global_variables()],
                    tf.global_variables())))
        with tf.variable_scope('', reuse=True):
            for var_name, saved_var_name in var_names:
                curr_var = name2var[saved_var_name]
                var_shape = curr_var.get_shape().as_list()
                if var_shape == saved_shapes[saved_var_name]:
                    restore_vars.append(curr_var)
                else:
                    print(("Not loading: %s." % saved_var_name))
        saver = tf.train.Saver(restore_vars)
        print("Restoring vars:")
        print(restore_vars)
        saver.restore(sess, save_file)

        # if exp_name == 'regular_120_50_200':
        #     print('#########################################################################')
        #     print('Conducting convergence experiment...')
        #     test_data_loader = Gen_Data_loader_text(BATCH_SIZE,charmap,inv_charmap,SEQ_LENGTH)
        #     test_data_loader.create_batches(real_data_test_file)
        #     results = convergence_experiment(sess, generator, test_data_loader)
        #     print('Saving results...')
        #     np.save('SeqGan_' + exp_name + '_conv_results',results)

        if FLAGS.dump_samples:
            print('###')
            print('Saving samples file...')
            generate_real_data_samples(
                sess, generator, BATCH_SIZE, BATCH_SIZE,
                "save/lm_eval_file_%0s.txt" % EXPERIMENT_NAME, inv_charmap,
                TOKEN_TYPE)

        print('###')
        print('Start Language Model Evaluation...')
        test_data_loader = Gen_Data_loader_text(BATCH_SIZE, charmap,
                                                inv_charmap, SEQ_LENGTH,
                                                TOKEN_TYPE)
        if FLAGS.test:
            test_data_loader.create_batches(real_data_test_file)
            print("USING %s TEST SET" % TOKEN_TYPE.upper())
        else:
            test_data_loader.create_batches(real_data_valid_file)
            # test_data_loader.create_batches(real_data_train_file)
            print("USING %s VALID SET" % TOKEN_TYPE.upper())
        BPC_direct = language_model_evaluation_direct(sess, generator,
                                                      test_data_loader)
        print("[%0s] BPC_direct = %f" % (exp_name, BPC_direct))

        if FLAGS.test:
            BPC_approx = language_model_evaluation_by_approximation(
                sess, generator, test_data_loader)
            print("[%0s] BPC_approx = %f" % (exp_name, BPC_approx))

        if FLAGS.epoch_exp:
            stats[0, i] = int(exp_name.split('_epoch_')[-1])
            stats[1, i] = BPC_direct

    if FLAGS.epoch_exp:
        np.save('direct_results_epoch_exp', stats)
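Note: the evaluation above reports bits per character (BPC). For reference only, BPC is the model's average per-character negative log-likelihood expressed in base 2; the helper below is a hypothetical sketch of the conversion from an NLL measured in nats, not a function from the repo.

import math

def nll_nats_to_bpc(nll_nats_per_token):
    """Convert an average NLL in nats/token into bits per character/token."""
    return nll_nats_per_token / math.log(2)

print(nll_nats_to_bpc(1.386))   # ~2.0 bits per token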
示例#25
0
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0

    #Build data loader for generator, testing and discriminator
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    #Build generator and its rollout
    generator = Generator(config=config_gen)
    # builds 3 neural networks
    generator.build()
    #  rollout network: finishes a partially generated sequence early, used to compute the reward
    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    target_params = cPickle.load(open('save/target_params.pkl'))
    target_lstm = TARGET_LSTM(config=config_gen,
                              params=target_params)  # The oracle model

    #Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    #Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    # collect all 'teller' variables; they live in the generator and rollout networks
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]  #Using name 'teller' here to prevent name collision of target LSTM
    # zip pairs the two iterables into tuples
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_upate = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))

    #Initialize all variables
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())

    #Initialize data loader of generator (generate_samples is defined in utils.py)
    #   the target_lstm network produces the "real" data and writes it to config_train.positive_file
    generate_samples(sess, target_lstm, config_train.batch_size,
                     config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    #Start pretraining
    log = open('save/experiment-log.txt', 'w')
    print 'Start pre-training generator...'
    log.write('pre-training...\n')
    for epoch in xrange(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in xrange(gen_data_loader.num_batch):
            #load the data generated by target_lstm above, i.e. the real samples used to pre-train the generator
            batch = gen_data_loader.next_batch()
            #supervised (MLE) training of the generator on real data
            _, g_loss = sess.run([gen_pre_upate, generator.pretrained_loss], feed_dict={generator.input_seqs_pre:batch,\
                                                                                    generator.input_seqs_mask:np.ones_like(batch)})
        if epoch % config_train.test_per_epoch == 0:
            #  measure how close the generator's samples are to the real data
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            #evaluate generation quality
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    print 'Start pre-training discriminator...'
    for t in range(config_train.dis_update_time_pre):
        print "Times: " + str(t)
        #   fake data from the generator plus real data from target_lstm are used for training
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        #  mix real and fake data
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x:
                    x_batch,
                    discriminator.input_y:
                    y_batch,
                    discriminator.dropout_keep_prob:
                    config_dis.dis_dropout_keep_prob
                }
                #minimize cross-entropy; this trains the scoring network that supplies rewards to the generator
                _ = sess.run(discriminator.train_op, feed)

    #Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    #Initialize global variables of optimizer for adversarial training
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    #Start adversarial training
    for total_batch in xrange(config_train.total_batch):
        for iter_gen in xrange(config_train.gen_update_time):

            #  sample from the generator; the LSTM produces a full sequence
            samples = sess.run(generator.sample_word_list_reshape)

            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            #calculate the reward at each specific step t by rolling out
            # the rollout network estimates the return of the chosen actions
            for iter_roll in xrange(config_train.rollout_num):

                # pass the generator's sampled words to the rollout net. Open question: samples already look like full sequences (unlike the paper), so why roll out?
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)

                rollout_list_stack = np.vstack(
                    rollout_list
                )  #shape: #batch_size * #rollout_step, #sequence length
                # Monte Carlo rollout into full sequences; the reward is computed from a Bellman-style backup
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                for r in xrange(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            #compute the reward
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            # use the reward to guide the generator's gradient update
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv], feed_dict={generator.input_seqs_adv:samples,\
                                                                                        generator.rewards:rewards})
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            #After the adversarial update, sample from the generator again and compare against the oracle (target_lstm, standing in for real data)
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            #defined in util.py
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)
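            # The NLL reported here is the oracle's (target_lstm's) negative log-likelihood
            # of sequences sampled from the generator -- the usual synthetic-data metric for
            # SeqGAN-style experiments; lower means the generated distribution is closer to
            # the oracle's.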

        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)

            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob
                    }
                    #Train the scoring (discriminator) network
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
    return sample_vocab

################################## main() #########################################

# load model path (./checkpoint)
load_model_path = './checkpoint/test4/seqGAN_ours' ##path changed

tf.reset_default_graph()

random.seed(SEED)
np.random.seed(SEED)

gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
vocab_size = len(vocab_to_int)  # 6448 (pos)
print(vocab_size)
dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)

generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)

discriminator = Discriminator(sequence_length=SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, embedding_size=dis_embedding_dim,
                              filter_sizes=dis_filter_sizes, num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)
rollout = ROLLOUT(generator, 0.8, word_embedding_matrix)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

sess = tf.Session(config=config)
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())

print('#########################################################################')
Example #27
0
    sample_vocab = [[int_to_vocab[i] for i in sample] for sample in sample_int]
    sample_result = []
    for i in range(len(sample_vocab)):
        sample_result.append(
            str(type_str[i]) + ' ' + ' '.join(sample_vocab[i]))
    return sample_result


random.seed(SEED)
np.random.seed(SEED)
assert START_TOKEN == 0

gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
vocab_size = len(vocab_to_int)
print(vocab_size)
dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)

generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH,
                      START_TOKEN, TYPE_SIZE)
discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                              batch_size=BATCH_SIZE,
                              num_classes=2,
                              word_embedding_matrix=word_embedding_matrix,
                              embedding_size=dis_embedding_dim,
                              filter_sizes=dis_filter_sizes,
                              num_filters=dis_num_filters,
                              type_size=TYPE_SIZE,
                              l2_reg_lambda=dis_l2_reg_lambda)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
Example #28
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    # assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size,
                          condition_size,
                          FEATURE_NUM,
                          BATCH_SIZE,
                          EMB_DIM,
                          COND_DIM,
                          HIDDEN_DIM,
                          Z_DIM,
                          SEQ_LENGTH,
                          START_TOKEN,
                          vocab_file,
                          condition_file,
                          word_vec=word_vec)

    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Checkpoint
    saver = tf.train.Saver()
    ckpt = get_ckpt(ckpt_dir)
    if ckpt is not None:
        print("Load checkpoints from: ", ckpt)
        saver.restore(sess, ckpt)

    # Load true data
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')

    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        g_loss, lstm_loss, recon_loss, kl_loss = pre_train_epoch(
            sess, generator, gen_data_loader)
        if epoch % 10 == 0:
            log.write(
                'pre-train epoch %d, g_loss: %f, lstm_loss: %f, recon_loss: %f, kl_loss: %f\n'
                % (epoch, g_loss, lstm_loss, recon_loss, kl_loss))
            print(
                'pre-train epoch %d, g_loss: %f, lstm_loss: %f, recon_loss: %f, kl_loss: %f'
                % (epoch, g_loss, lstm_loss, recon_loss, kl_loss))
            generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                             generated_num, eval_file)

            if epoch % 20 == 0:
                saver.save(sess,
                           os.path.join(ckpt_dir, 'checkpoint_' + str(epoch)))

    print('Start pre-training discriminator...')
    # Train 3 epochs on the generated data and repeat 50 times
    for _ in range(50):
        generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                         generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)
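    # In the reference SeqGAN code, the second argument (0.8) is the rollout policy's
    # update rate: rollout.update_params() moves the rollout network's weights toward
    # the generator's as an exponential moving average after each adversarial step.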

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess, gen_data_loader.next_batch())
            rewards = rollout.get_reward(sess, samples, 16, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # # Test
        # if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
        #     generate_samples(sess, generator, gen_data_loader, BATCH_SIZE, generated_num, eval_file)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                             generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()
Example #29
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, FLAGS.length)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE,
                                             FLAGS.length)  # For testing
    vocab_size = 1000
    file = open('save/target_params.pkl', 'rb')
    target_params = cPickle.load(file)

    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    discriminator = Discriminator(SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  dis_emb_dim=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  batch_size=BATCH_SIZE,
                                  hidden_dim=HIDDEN_DIM,
                                  start_token=START_TOKEN,
                                  goal_out_size=GOAL_OUT_SIZE,
                                  step_size=4)
    sslgan = SSLGAN(SEQ_LENGTH,
                    num_classes=2,
                    vocab_size=vocab_size,
                    emb_dim=EMB_DIM,
                    dis_emb_dim=dis_embedding_dim,
                    filter_sizes=dis_filter_sizes,
                    num_filters=dis_num_filters,
                    batch_size=BATCH_SIZE,
                    hidden_dim=HIDDEN_DIM,
                    start_token=START_TOKEN,
                    goal_out_size=GOAL_OUT_SIZE,
                    goal_size=GOAL_SIZE,
                    step_size=4,
                    D_model=discriminator,
                    learning_rate=LEARNING_RATE)
    if SEQ_LENGTH == 10:
        target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                                  SEQ_LENGTH, START_TOKEN)  # The oracle model
    else:
        target_lstm = TARGET_LSTM20(vocab_size, BATCH_SIZE, EMB_DIM,
                                    HIDDEN_DIM, SEQ_LENGTH, START_TOKEN,
                                    target_params)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                     positive_file, 0)
    for a in range(1):
        g = sess.run(sslgan.gen_x,
                     feed_dict={
                         sslgan.drop_out: 0.8,
                         sslgan.train: 1
                     })
        print(g)

        print("epoch:", a, "  ")

    log = open('save/experiment-log.txt', 'w')
    gen_data_loader.create_batches(positive_file)
    saver_variables = tf.global_variables()
    saver = tf.train.Saver(saver_variables)
    model = tf.train.latest_checkpoint(model_path)
    print(model)
    if FLAGS.restore and model:
        # model = tf.train.latest_checkpoint(model_path)
        # if model and FLAGS.restore:
        if model_path + '/' + FLAGS.model:
            print(model_path + '/' + FLAGS.model)
            saver.restore(sess, model_path + '/' + FLAGS.model)
        else:
            saver.restore(sess, model)
    else:
        if FLAGS.resD and model_path + '/' + FLAGS.model:
            print(model_path + '/' + FLAGS.model)
            saver.restore(sess, model_path + '/' + FLAGS.model)

            print('Start pre-training...')
            log.write('pre-training...\n')
            for epoch in range(PRE_EPOCH_NUM):
                loss = pre_train_epoch(sess, sslgan, gen_data_loader)
                if epoch % 5 == 0:
                    generate_samples(sess, sslgan, BATCH_SIZE, generated_num,
                                     eval_file, 0)
                    likelihood_data_loader.create_batches(eval_file)
                    test_loss = target_loss(sess, target_lstm,
                                            likelihood_data_loader)
                    print('pre-train epoch ', epoch, 'test_loss ', test_loss)
                    buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                        test_loss) + '\n'
                    log.write(buffer)
                    generate_samples(sess, target_lstm, BATCH_SIZE,
                                     generated_num, eval_file, 0)
                    likelihood_data_loader.create_batches(eval_file)
                    test_loss = target_loss(sess, target_lstm,
                                            likelihood_data_loader)
                    print("Groud-Truth:", test_loss)
            saver.save(sess, model_path + '/sslgan_pre')
        else:
            print('Start pre-training discriminator...')
            # Pre-train the discriminator: 10 outer rounds, regenerating data 5 times per round and training 3 epochs each time
            for i in range(10):
                for _ in range(5):
                    generate_samples(sess, sslgan, BATCH_SIZE, generated_num,
                                     negative_file, 0)
                    generate_samples(sess, target_lstm, BATCH_SIZE,
                                     generated_num, positive_file, 0)
                    # gen_data_loader.create_batches(positive_file)
                    dis_data_loader.load_train_data(positive_file,
                                                    negative_file)
                    for _ in range(3):
                        dis_data_loader.reset_pointer()
                        for it in range(dis_data_loader.num_batch):
                            x_batch, y_batch = dis_data_loader.next_batch()
                            feed = {
                                discriminator.D_input_x: x_batch,
                                discriminator.D_input_y: y_batch,
                                discriminator.dropout_keep_prob: dis_dropout_keep_prob
                            }
                            D_loss, _ = sess.run([
                                discriminator.D_loss, discriminator.D_train_op
                            ], feed)
                            # # print 'D_loss ', D_loss
                            # buffer =  str(D_loss) + '\n'
                            # log.write(buffer)
                    sslgan.update_feature_function(discriminator)
                saver.save(sess, model_path + '/sslgan_preD')

                # saver.save(sess, model_path + '/sslgan')
                #  pre-train generator
                print('Start pre-training...')
                log.write('pre-training...\n')
                for epoch in range(PRE_EPOCH_NUM // 10):
                    loss = pre_train_epoch(sess, sslgan, gen_data_loader)
                    if epoch % 5 == 0:
                        generate_samples(sess, sslgan, BATCH_SIZE,
                                         generated_num, eval_file, 0)
                        likelihood_data_loader.create_batches(eval_file)
                        test_loss = target_loss(sess, target_lstm,
                                                likelihood_data_loader)
                        print('pre-train epoch ', epoch, 'test_loss ',
                              test_loss)
                        buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                            test_loss) + '\n'
                        log.write(buffer)
                        generate_samples(sess, target_lstm, BATCH_SIZE,
                                         generated_num, eval_file, 0)
                        likelihood_data_loader.create_batches(eval_file)
                        test_loss = target_loss(sess, target_lstm,
                                                likelihood_data_loader)
                        print("Groud-Truth:", test_loss)
            saver.save(sess, model_path + '/sslgan_pre')

    gencircle = 1
    #
    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):

            for gi in range(gencircle):
                samples = sslgan.generate(sess, 1.0, 1)
                rewards = get_reward(sslgan, discriminator, sess, samples, 4,
                                     dis_dropout_keep_prob)
                feed = {
                    sslgan.x: samples,
                    sslgan.reward: rewards,
                    sslgan.drop_out: 1.0
                }
                _, _, g_loss, w_loss = sess.run([
                    sslgan.manager_updates, sslgan.worker_updates,
                    sslgan.goal_loss, sslgan.worker_loss
                ],
                                                feed_dict=feed)
                print('total_batch: ', total_batch, "  ", g_loss, "  ", w_loss)
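                # manager_updates / worker_updates follow a LeakGAN-style hierarchical
                # setup: the manager receives leaked discriminator features and emits goal
                # vectors (optimizing goal_loss), while the worker is trained to generate
                # tokens that realize those goals (worker_loss), using the rewards
                # computed by get_reward above.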

        # Test
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, sslgan, BATCH_SIZE, generated_num,
                             eval_file, 0)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)
            generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                             eval_file, 0)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print("Groud-Truth:", test_loss)

        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, sslgan, BATCH_SIZE, generated_num,
                             negative_file, 0)
            generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                             positive_file, 0)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.D_input_x: x_batch,
                        discriminator.D_input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    D_loss, _ = sess.run(
                        [discriminator.D_loss, discriminator.D_train_op], feed)
                    # print 'D_loss ', D_loss
            sslgan.update_feature_function(discriminator)
    log.close()
Example #30
0
# vocab.append('<u_k_n_o_w_n>')
# embd.append(['0' for _ in range(embedding_size)])
# src_vocab_size = len(vocab)
# embedding = np.asarray(embd)
#vocab to int
# vocab_to_int = {}
# for i in range(src_vocab_size):
#     vocab_to_int[vocab[i]] = i

print('Glove vector loaded. Total vocab: ', src_vocab_size,
      '. embedding_size: ', embedding_size)

gen_data_loader = Gen_Data_loader(BATCH_SIZE)
#likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing

dis_data_loader = Dis_dataloader(BATCH_SIZE)

generator = Generator(src_vocab_size, BATCH_SIZE, embedding_size, HIDDEN_DIM,
                      embedding, SEQ_LENGTH, START_TOKEN, gen_filter_sizes,
                      gen_num_filters)
# target_params = cPickle.load(open('save/target_params_py3.pkl', 'rb'))
# target_lstm = TARGET_LSTM(src_vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

#TODO change discriminator's embedding layer
discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                              num_classes=2,
                              vocab_size=src_vocab_size,
                              embedding_size=embedding_size,
                              filter_sizes=dis_filter_sizes,
                              num_filters=dis_num_filters,
                              l2_reg_lambda=dis_l2_reg_lambda)
Example #31
0
def main():
    # random.seed(SEED)
    # np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing

    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator2(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN,learning_rate=0.03)


    discriminator = RNNDiscriminator2(sequence_length=SEQ_LENGTH, nrof_class=2, vocab_size=vocab_size, emb_dim=dis_embedding_dim,
                                     batch_size = dis_batch_size,hidden_dim = 2*HIDDEN_DIM, learning_rate = 0.03)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Create Saver
    saver_pretrain = tf.train.Saver(max_to_keep=10)
    saver = tf.train.Saver(max_to_keep=10)

    model_idx = 1
    fname = 'model' + str(model_idx)
    model_save_path = './Model/' + fname + '/'

    while os.path.exists(model_save_path):
        model_idx += 1
        fname = 'model' + str(model_idx)
        model_save_path = './Model/' + fname + '/'

    pre_model_save_path = './Model/' + fname + '_pre/'

    os.makedirs(model_save_path)
    os.makedirs(pre_model_save_path)
    # os.makedirs(os.path.join('./log', fname))

    pretrain_fname = fname+'_pre'

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution

    gen_data_loader.create_batches(positive_file)

    #  pre-train generator
    print 'Start pre-training...'
    early_stop_buffer = [10.]*5
    for pretrain_cnt, epoch in enumerate(xrange(PRE_EPOCH_NUM)):
        loss = pre_train_epoch(sess, generator, gen_data_loader)

        if epoch % 2 == 0:
            # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_real_file)
            test_loss = target_loss(sess, generator, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            early_stop_buffer = early_stop_buffer[1:]
            early_stop_buffer.append(test_loss)
            if all(early_stop_buffer[0] < np.asarray(early_stop_buffer[1:])):
                break

            elif all(early_stop_buffer[-1] < np.asarray(early_stop_buffer[:-1])):   # save on local min
                saver_pretrain.save(sess, os.path.join(pre_model_save_path, pretrain_fname), global_step=epoch, write_meta_graph=False)

                metagraph_filename = os.path.join(pre_model_save_path, pretrain_fname + '.meta')

                if not os.path.exists(metagraph_filename):
                    saver.export_meta_graph(metagraph_filename)

    saver.restore(sess,tf.train.latest_checkpoint(pre_model_save_path))
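    # Early-stopping heuristic used above: keep the last five evaluation losses; stop
    # pre-training once the oldest entry is smaller than every later one (no recent
    # improvement), and checkpoint whenever the newest entry is the minimum of the
    # buffer. The restore call above then reloads the best (lowest-loss) pretrained
    # generator before pre-training the discriminator.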


    print 'Start pre-training discriminator...'
    # Train 3 epochs on the generated data and repeat 50 times
    for e in range(50):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch
                    }
                _ = sess.run(discriminator.train_op, feed)
            print 'Epoch {}'.format(e)
    rollout = ROLLOUT(generator, 0.7)

    print '#########################################################################'
    print 'Start Adversarial Training...'


    early_stop_buffer = [10.] * 6
    for total_batch in range(TOTAL_BATCH):

        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, SAMP_NUM, discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        if total_batch % 2 == 0 or total_batch == TOTAL_BATCH - 1:
            # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_real_file)
            test_loss = target_loss(sess, generator, likelihood_data_loader)
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss

            # early_stop_buffer = early_stop_buffer[1:]
            # early_stop_buffer.append(test_loss)
            # if all(early_stop_buffer[0] < np.asarray(early_stop_buffer[1:])):
            #     break

            # elif all(early_stop_buffer[-1] < np.asarray(early_stop_buffer[:-1])):   # save on local min
            saver.save(sess, os.path.join(model_save_path, fname), global_step=total_batch, write_meta_graph=False)

            metagraph_filename = os.path.join(model_save_path, fname + '.meta')

            if not os.path.exists(metagraph_filename):
                saver.export_meta_graph(metagraph_filename)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(3):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)


            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {discriminator.input_x: x_batch, discriminator.input_y: y_batch }
                _ = sess.run(discriminator.train_op, feed)
Example #32
0
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    vocab_size = 5000
    dis_data_loader = Dis_dataloader(re_batch_size)

    # TODO: Reimplement this class with the same interface.
    # generator = GeneratorTransformer(
    #     vocab_size,
    #     BATCH_SIZE,
    #     SEQ_LENGTH,
    #     START_TOKEN
    # )
    generator = Generator(
        vocab_size,
        BATCH_SIZE,
        EMB_DIM,
        HIDDEN_DIM,
        SEQ_LENGTH,
        START_TOKEN,
        MID_LAYER_G,
    )
    # TODO: Reimplement this class with the same interface.
    rewarder = Rewarder(
        vocab_size,
        BATCH_SIZE,
        EMB_DIM * 4,
        HIDDEN_DIM * 4,
        SEQ_LENGTH,
        START_TOKEN,
        MID_LAYER_R,
        l2_reg_lambda=re_l2_reg_lambda,
    )
    target_params = pickle.load(open("save/target_params.pkl", "rb"), encoding="latin1")
    # TODO: Reimplement this class with the same interface. (target_transformer)
    # I think we leave this as is, since it's the distribution we're trying to match? (Cailin)
    target_lstm = TARGET_LSTM(
        vocab_size,
        BATCH_SIZE,
        EMB_DIM,
        HIDDEN_DIM,
        SEQ_LENGTH,
        START_TOKEN,
        target_params,
    )  # The oracle model

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables())

    # First, use the oracle model to provide the positive examples,
    #   which are sampled from the oracle data distribution
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)
    # ground_loss = target_loss(sess, target_lstm, gen_data_loader)
    # print('Ground-Truth:', ground_loss)

    log = open("save/experiment-ent" + str(entropy_w), "w")
    #  pre-train generator
    if restore is False:
        print("Start pre-training...")
        log.write("pre-training...\n")
        for epoch in range(PRE_EPOCH_NUM):
            loss = pre_train_epoch(sess, generator, gen_data_loader)
            if epoch % 5 == 0:
                generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
                likelihood_data_loader.create_batches(eval_file)
                test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                print("pre-train epoch ", epoch, "test_loss ", test_loss)
                buffer = "epoch:\t" + str(epoch) + "\tnll:\t" + str(test_loss) + "\n"
                log.write(buffer)

        print("Start pre-training rewarder...")
        start = time.time()
        for _ in range(1):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(1):
                dis_data_loader.reset_pointer()
                r_losses = []
                for it in range(dis_data_loader.num_batch):
                    x_text = dis_data_loader.next_batch()
                    _, r_loss = rewarder.reward_train_step(
                        sess,
                        x_text,
                        np.ones(BATCH_SIZE),
                        1.0,
                        re_dropout_keep_prob,
                        0.01,
                    )
                    r_losses.append(r_loss)
                print("reward_loss", np.mean(r_losses))
        speed = time.time() - start
        print("Reward pre_training Speed:{:.3f}".format(speed))

        checkpoint_path = os.path.join("save", "exper_40.ckpt")
        saver.save(sess, checkpoint_path)
    else:
        print("Restore pretrained model ...")
        log.write("Restore pre-trained model...\n")
        ckpt = tf.train.get_checkpoint_state("save")
        saver.restore(sess, ckpt.model_checkpoint_path)

    # By setting the parameters to 0.0 and 1.0, we do not use the mixed-policy RL training from SeqGAN
    rollout = ROLLOUT(generator, 0.0, 1.0)

    print("#########################################################################")
    print("Start Adversarial Training...")
    log.write("adversarial training...\n")
    for total_batch in range(TOTAL_BATCH):

        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = "epoch:\t" + str(total_batch) + "\tnll:\t" + str(test_loss) + "\n"
            print("total_batch: ", total_batch, "test_loss: ", test_loss)
            log.write(buffer)

        # Train the generator for one step
        start = time.time()
        g_losses = []
        # Draw trajectories (sequences) from generator
        off_samples, off_probs = off_policy_samples(sess, rollout, BATCH_SIZE, off_num)
        avg_reward = []
        for g_it in range(1):
            # Compute the Monte Carlo rollout reward for each trajectory
            for it in range(off_num // BATCH_SIZE):
                rewards = rollout.get_reward(sess, off_samples[it], 8, rewarder)
                avg_reward.append(rewards)
            # Perform gradient update for generator
            baseline = np.zeros(SEQ_LENGTH)
            for it in range(1):
                for it2 in range(off_num // BATCH_SIZE):
                    _, g_loss = generator.rl_train_step(
                        sess,
                        off_samples[it2],
                        avg_reward[it2],
                        baseline,
                        off_probs[it2],
                        entropy_w,
                        G_rate,
                    )
                    g_losses.append(g_loss)
        speed = time.time() - start
        print(
            "MaxentPolicy Gradient {} round, Speed:{:.3f}, Loss:{:.3f}".format(
                total_batch, speed, np.mean(g_losses)
            )
        )
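        # Maximum-entropy policy-gradient step: trajectories are drawn off-policy from
        # the rollout network, and each rl_train_step presumably weights log-probabilities
        # by (reward - baseline), corrects for the sampling distribution via off_probs,
        # and adds an entropy bonus scaled by entropy_w.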

        # Update roll-out parameters
        rollout.update_params()

        # Train the rewarder
        start = time.time()
        r_loss_list = []
        for _ in range(2):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_text = dis_data_loader.next_batch()
                    weights = rewarder.reward_weight(sess, x_text, generator)
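                    # The rewarder's learning rate decays in discrete steps:
                    # R_rate * exp(-(total_batch // R_decay)), i.e. it drops by a factor
                    # of e every R_decay adversarial batches.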
                    _, r_loss = rewarder.reward_train_step(
                        sess,
                        x_text,
                        weights,
                        1,
                        re_dropout_keep_prob,
                        R_rate * np.exp(-(total_batch // R_decay)),
                    )
                    r_loss_list.append(r_loss)
        speed = time.time() - start
        print(
            "Reward training {} round, Speed:{:.3f}, Loss:{:.3f}".format(
                total_batch, speed, np.mean(r_loss_list)
            )
        )

    log.close()
Example #33
0
File: Main.py  Project: NemoNone/LeakGAN
def main():
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE,SEQ_LENGTH)
    vocab_size = 4839
    dis_data_loader = Dis_dataloader(BATCH_SIZE,SEQ_LENGTH)
    discriminator = Discriminator(SEQ_LENGTH,num_classes=2,vocab_size=vocab_size,dis_emb_dim=dis_embedding_dim,filter_sizes=dis_filter_sizes,num_filters=dis_num_filters,
                        batch_size=BATCH_SIZE,hidden_dim=HIDDEN_DIM,start_token=START_TOKEN,goal_out_size=GOAL_OUT_SIZE,step_size=4)
    leakgan = LeakGAN(SEQ_LENGTH,num_classes=2,vocab_size=vocab_size,emb_dim=EMB_DIM,dis_emb_dim=dis_embedding_dim,filter_sizes=dis_filter_sizes,num_filters=dis_num_filters,
                        batch_size=BATCH_SIZE,hidden_dim=HIDDEN_DIM,start_token=START_TOKEN,goal_out_size=GOAL_OUT_SIZE,goal_size=GOAL_SIZE,step_size=4,D_model=discriminator)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    for a in range(1):
        g = sess.run(leakgan.gen_x,feed_dict={leakgan.drop_out:0.8,leakgan.train:1})
        print g

        print "epoch:",a,"  "

    log = open('save/experiment-log.txt', 'w')
    generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file, 0)
    gen_data_loader.create_batches(positive_file)
    saver_variables = tf.global_variables()
    saver = tf.train.Saver(saver_variables)
    model = tf.train.latest_checkpoint(model_path)
    print  model
    if FLAGS.restore and model:
        # model = tf.train.latest_checkpoint(model_path)
        # if model and FLAGS.restore:
        if model_path+'/' + FLAGS.model:
            print model_path+'/' + FLAGS.model
            saver.restore(sess, model_path+'/' + FLAGS.model)
        else:
            saver.restore(sess, model)
    else:
        if FLAGS.resD and model_path + '/' + FLAGS.model:
                print model_path + '/' + FLAGS.model
                saver.restore(sess, model_path + '/' + FLAGS.model)

                print 'Start pre-training...'
                log.write('pre-training...\n')
                for epoch in xrange(PRE_EPOCH_NUM):
                    loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                    if epoch % 5 == 0:
                        generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file)
                    buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(loss) + '\n'
                    log.write(buffer)
                saver.save(sess, model_path + '/leakgan_pre')
        else:
                print 'Start pre-training discriminator...'
                # Pre-train the discriminator: 16 outer rounds, regenerating data 5 times per round and training 3 epochs each time
                for i in range(16):
                    for _ in range(5):
                        generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file,0)
                        # gen_data_loader.create_batches(positive_file)
                        dis_data_loader.load_train_data(positive_file, negative_file)
                        for _ in range(3):
                            dis_data_loader.reset_pointer()
                            for it in xrange(dis_data_loader.num_batch):
                                x_batch, y_batch = dis_data_loader.next_batch()
                                feed = {
                                    discriminator.D_input_x: x_batch,
                                    discriminator.D_input_y: y_batch,
                                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                                }
                                D_loss,_ = sess.run([discriminator.D_loss,discriminator.D_train_op], feed)
                                # print 'D_loss ', D_loss
                                buffer =  str(D_loss) + '\n'
                                log.write(buffer)
                        leakgan.update_feature_function(discriminator)
                    saver.save(sess, model_path + '/leakgan_preD')

            # saver.save(sess, model_path + '/leakgan')
        #  pre-train generator
                    print 'Start pre-training...'
                    log.write('pre-training...\n')
                    for epoch in xrange(PRE_EPOCH_NUM/16):
                        loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                        if epoch % 5 == 0:
                            generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file,0)
                        print 'pre-train epoch ', epoch, 'test_loss ', loss
                        buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(loss) + '\n'
                        log.write(buffer)
                saver.save(sess, model_path + '/leakgan_pre')

    gencircle = 1
    #
    print '#########################################################################'
    print 'Start Adversarial Training...'
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):

            for gi in range(gencircle):
                samples = leakgan.generate(sess,1.0,1)
                rewards = get_reward(leakgan, discriminator,sess, samples, 4, dis_dropout_keep_prob,total_batch,gen_data_loader)
                feed = {leakgan.x: samples, leakgan.reward: rewards,leakgan.drop_out:1.0}
                _,_,g_loss,w_loss = sess.run([leakgan.manager_updates,leakgan.worker_updates,leakgan.goal_loss,leakgan.worker_loss], feed_dict=feed)
                print 'total_batch: ', total_batch, "  ",g_loss,"  ", w_loss

        # Test
        if total_batch % 10 == 1 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, leakgan, BATCH_SIZE, generated_num, "./save/coco_" + str(total_batch) + ".txt", 0)
            saver.save(sess, model_path + '/leakgan', global_step=total_batch)
        if total_batch % 15 == 0:
             for epoch in xrange(1):
                 loss = pre_train_epoch(sess, leakgan, gen_data_loader)
        # Train the discriminator
        for _ in range(5):
            generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file,0)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.D_input_x: x_batch,
                        discriminator.D_input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    D_loss, _ = sess.run([discriminator.D_loss, discriminator.D_train_op], feed)
                    # print 'D_loss ', D_loss
            leakgan.update_feature_function(discriminator)
    log.close()
Example #34
0
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()

    np.random.seed(config_train.seed)

    assert config_train.start_token == 0
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    generator = Generator(config=config_gen)
    generator.build()

    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    target_params = pickle.load(open('save/target_params.pkl','rb'),encoding='iso-8859-1')
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params) # The oracle model


    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()


    # Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    var_pretrained = [v for v in tf.trainable_variables() if 'teller' in v.name]
    gradients, variables = zip(
        *pretrained_optimizer.compute_gradients(generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(zip(gradients, variables))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    generate_samples(sess,target_lstm,config_train.batch_size,config_train.generated_num,config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    log = open('save/experiment-log.txt','w')
    print('Start pre-training generator....')

    log.write('pre-training...\n')

    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _,g_loss = sess.run([gen_pre_update,generator.pretrained_loss],feed_dict={generator.input_seqs_pre:batch,
                                                                                      generator.input_seqs_mask:np.ones_like(batch)})
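            # Supervised (MLE) pre-training with teacher forcing: the all-ones
            # input_seqs_mask means every position of every oracle sequence contributes
            # to generator.pretrained_loss.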

        if epoch % config_train.test_per_epoch == 0:
            #Evaluation: sample a batch of sequences from the generator,
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.eval_file)
            # create a data loader over these generated sequences,
            likelihood_data_loader.create_batches(config_train.eval_file)
            # use the oracle to compute the cross-entropy (NLL) loss,
            test_loss = target_loss(sess,target_lstm,likelihood_data_loader)
            # and print it and write it to the log
            print('pre-train ',epoch, ' test_loss ',test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)


    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)
        for _ in range(config_train.dis_update_time_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch,y_batch = dis_data_loader.next_batch()
                feed_dict = {
                    discriminator.input_x : x_batch,
                    discriminator.input_y : y_batch,
                    discriminator.dropout_keep_prob : config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op,feed_dict)



    # Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize global variables of optimizer for adversarial training
    uninitialized_var = [e for e in tf.global_variables() if e not in tf.trainable_variables()]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    # Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshpae)

            feed = {'pred_seq_rollout:0':samples}
            reward_rollout = []
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,feed_dict=feed)
                # np.vstack stacks the arrays vertically (row by row).
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:rollout_list_stack,discriminator.dropout_keep_prob:1.0
                })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:samples,discriminator.dropout_keep_prob:1.0
                })
                reward_allseq = np.concatenate((reward_rollout_seq,reward_last_tok),axis=0)[:,1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(r,config_gen.gen_batch_size * config_gen.sequence_length,config_gen.gen_batch_size)])

                reward_rollout.append(np.array(reward_tmp))
            # Average over all rollouts, then update the generator once per generator step
            # (matching the structure of the equivalent loop in the example above)
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv],
                                   feed_dict={generator.input_seqs_adv: samples,
                                              generator.rewards: rewards})


        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print ('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)


        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)

            for _ in range(config_train.dis_update_time_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch,y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x:x_batch,
                        discriminator.input_y:y_batch,
                        discriminator.dropout_keep_prob:config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op,feed)

    log.close()
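
# --------------------------------------------------------------------------
# Illustrative aside (not part of the example above): a minimal, self-contained
# NumPy sketch of how the flat rollout-reward vector is regrouped into
# per-timestep rewards for each sample, mirroring the
# reward_allseq[range(r, batch_size * sequence_length, batch_size)] indexing
# used in the adversarial training loops above. All names and values here are
# made up for illustration only.
import numpy as np

B, T = 2, 4  # batch size, sequence length
# Rewards arrive step-major: all samples for step 1, then all samples for step 2, ...
# Encode the reward for (step t, sample r) as 10*t + r so the layout is easy to read.
reward_allseq = np.array(
    [10 * t + r for t in range(1, T + 1) for r in range(B)], dtype=np.float32)  # shape (B*T,)

per_sample = []
for r in range(B):
    # Stride-B indexing picks out sample r's reward at every timestep t = 1..T
    per_sample.append(reward_allseq[range(r, B * T, B)])
rewards = np.array(per_sample)  # shape (B, T), the layout fed to generator.rewards above

print(rewards)
# [[10. 20. 30. 40.]
#  [11. 21. 31. 41.]]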