예제 #1
0
def main():
    """Restore a trained LeakGAN checkpoint and write generated sentences.

    ``SEQ_LENGTH`` and ``vocab_size`` are read from the pickle at
    ``pickle_loc``. The Discriminator/LeakGAN graph is rebuilt from the
    module-level hyper-parameters, and ``model_path/FLAGS.model`` is restored.
    ``generated_num`` samples are then written to ``generate_file`` and handed
    to ``convertor`` with ``filedir='../data/save_generator/'``.

    Exits with status 1 unless both ``FLAGS.restore`` and ``FLAGS.model`` are
    set.
    """
    import sys

    print('start time : ')
    print(datetime.now())
    random.seed(SEED)
    np.random.seed(SEED)

    # Fail fast: nothing below is useful without a checkpoint to restore,
    # and a missing argument is an error, so exit with a non-zero status.
    if not (FLAGS.restore and FLAGS.model):
        print("please input all arguments!")
        sys.exit(1)

    # Only the sequence length and vocabulary size are needed; the context
    # manager guarantees the file handle is closed.
    with open(pickle_loc, 'rb') as pickle_file:
        _, _, _, SEQ_LENGTH, vocab_size = cPickle.load(pickle_file)
    assert START_TOKEN == 0

    discriminator = Discriminator(SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  dis_emb_dim=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  batch_size=BATCH_SIZE,
                                  hidden_dim=HIDDEN_DIM,
                                  start_token=START_TOKEN,
                                  goal_out_size=GOAL_OUT_SIZE,
                                  step_size=4)
    leakgan = LeakGAN(SEQ_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      emb_dim=EMB_DIM,
                      dis_emb_dim=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      batch_size=BATCH_SIZE,
                      hidden_dim=HIDDEN_DIM,
                      start_token=START_TOKEN,
                      goal_out_size=GOAL_OUT_SIZE,
                      goal_size=GOAL_SIZE,
                      step_size=4,
                      D_model=discriminator)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1

    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=maxModelSave)

    # FLAGS.model is non-empty here, so the path is always well-formed; the
    # old "if <non-empty string>:" check could never be false. A missing or
    # invalid checkpoint is reported by saver.restore itself.
    checkpoint = model_path + '/' + FLAGS.model
    print(checkpoint)
    saver.restore(sess, checkpoint)

    print("start sentence generate!!")
    generate_samples(sess, leakgan, BATCH_SIZE, generated_num, generate_file, 0)
    convertor(generate_file, filedir='../data/save_generator/')
    print("sentenceGenerate.py finish!")
예제 #2
0
def main():
    """Pre-train (or restore) LeakGAN, then run adversarial training.

    Behaviour is selected by the module-level ``FLAGS``:

    * ``restore`` and ``model``: restore ``model_path/model`` and go straight
      to adversarial training.
    * ``resD`` (requires ``model``): restore ``model_path/model``, run
      ``PRE_EPOCH_NUM`` generator MLE epochs, and save ``leakgan_pre``.
    * otherwise: pre-train the generator, then the discriminator, and save
      ``leakgan_pre``.

    Adversarial training runs ``TOTAL_BATCH // 10`` outer rounds of ten
    generator updates. After each generator update, a sample file
    ``./save/movie_<step>.txt`` is written and converted, and the
    discriminator gets 15 training rounds. Each outer round ends with five
    MLE epochs and five more discriminator rounds. Losses are appended to
    ``save/experiment-log.txt``.
    """
    import sys

    print('start time : ')
    print(datetime.now())
    random.seed(SEED)
    np.random.seed(SEED)

    with open(pickle_loc, 'rb') as pickle_file:
        _, _, _, SEQ_LENGTH, vocab_size = cPickle.load(pickle_file)
    print('SEQ_LENGTH %s vocab_size %s' % (SEQ_LENGTH, vocab_size))
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    discriminator = Discriminator(SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  dis_emb_dim=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  batch_size=BATCH_SIZE,
                                  hidden_dim=HIDDEN_DIM,
                                  start_token=START_TOKEN,
                                  goal_out_size=GOAL_OUT_SIZE,
                                  step_size=4)
    leakgan = LeakGAN(SEQ_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      emb_dim=EMB_DIM,
                      dis_emb_dim=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      batch_size=BATCH_SIZE,
                      hidden_dim=HIDDEN_DIM,
                      start_token=START_TOKEN,
                      goal_out_size=GOAL_OUT_SIZE,
                      goal_size=GOAL_SIZE,
                      step_size=4,
                      D_model=discriminator)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 1
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Smoke test: draw one batch from the freshly initialised generator.
    g = sess.run(leakgan.gen_x,
                 feed_dict={leakgan.drop_out: 0.8, leakgan.train: 1})
    print(g)
    print('epoch: 0   ')

    def train_discriminator(num_rounds, log_file):
        """Run ``num_rounds`` discriminator rounds, logging every D loss.

        Each round samples fresh negatives into ``negative_file``, reloads the
        discriminator data, and performs 3 "epochs". After each epoch it calls
        ``leakgan.update_feature_function(discriminator)``.
        """
        for _round in range(num_rounds):
            generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                             negative_file, 0)
            dis_data_loader.load_train_data(positive_file, negative_file)
            for _epoch in range(3):
                # NOTE(review): each "epoch" is a single step on the batch
                # returned right after reset_pointer(), presumably always the
                # first batch. The sibling scripts loop over
                # dis_data_loader.num_batch instead. Kept as-is to preserve
                # the training budget -- confirm this is intended.
                dis_data_loader.reset_pointer()
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.D_input_x: x_batch,
                    discriminator.D_input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                }
                D_loss, _ = sess.run(
                    [discriminator.D_loss, discriminator.D_train_op], feed)
                log_file.write(str(D_loss) + '\n')
                leakgan.update_feature_function(discriminator)

    # The context manager closes the log even if training raises.
    with open('save/experiment-log.txt', 'w') as log:
        generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                         negative_file, 0)
        gen_data_loader.create_batches(positive_file)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=maxModelSave)

        if FLAGS.restore and FLAGS.model:
            # FLAGS.model is non-empty, so the path is always well-formed;
            # saver.restore reports a missing checkpoint itself.
            checkpoint = model_path + '/' + FLAGS.model
            print(checkpoint)
            saver.restore(sess, checkpoint)
        elif FLAGS.resD:
            # Concatenating a missing FLAGS.model used to raise TypeError;
            # report the missing argument explicitly instead.
            if not FLAGS.model:
                print('input all arguments, "resD" and "model"')
                sys.exit(1)
            checkpoint = model_path + '/' + FLAGS.model
            print(checkpoint)
            saver.restore(sess, checkpoint)

            print('Start pre-training...')
            log.write('pre-training...\n')
            for epoch in range(PRE_EPOCH_NUM):
                loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                if epoch % 5 == 0:
                    generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                                     negative_file)
                log.write('epoch:\t' + str(epoch) + '\tnll:\t' + str(loss) + '\n')
            saver.save(sess, model_path + '/leakgan_pre')
        else:
            # Pre-train the generator with maximum likelihood.
            print('Start pre-training...')
            log.write('pre-training...\n')
            for epoch in range(PRE_EPOCH_NUM):
                loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                if epoch % 5 == 0:
                    generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                                     negative_file, 0)
                print('pre-train epoch  %s test_loss  %s' % (epoch, loss))
                log.write('epoch:\t' + str(epoch) + '\tnll:\t' + str(loss) + '\n')

            print('Start pre-training discriminator...')
            train_discriminator(PRE_EPOCH_NUM, log)
            saver.save(sess, model_path + '/leakgan_pre')

        gencircle = 1
        print('#########################################################################')
        print('Start Adversarial Training...')
        print('start Adv time : ')
        print(datetime.now())
        log.write('adversarial training...\n')
        for total_batch in range(TOTAL_BATCH // 10):
            for iter1 in range(10):
                step = total_batch * 10 + iter1

                # Train the generator with discriminator-derived rewards.
                for _ in range(gencircle):
                    samples = leakgan.generate(sess, 1.0, 1)
                    rewards = get_reward(leakgan, discriminator, sess, samples, 4,
                                         dis_dropout_keep_prob, step,
                                         gen_data_loader)
                    feed = {leakgan.x: samples,
                            leakgan.reward: rewards,
                            leakgan.drop_out: 1.0}
                    _, _, g_loss, w_loss = sess.run(
                        [leakgan.manager_updates, leakgan.worker_updates,
                         leakgan.goal_loss, leakgan.worker_loss],
                        feed_dict=feed)
                    print('total_batch:  %s    %s    %s' % (step, g_loss, w_loss))

                # Test: dump and convert samples for this step.
                test_file_name = "./save/movie_" + str(step) + ".txt"
                generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                                 test_file_name, 0)
                convertor(test_file_name, filedir='save/')
                if iter1 == 1 or step == TOTAL_BATCH - 1:
                    saver.save(sess, model_path + '/leakgan', global_step=step)
                train_discriminator(15, log)

            # Keep the generator anchored to the data with a few MLE epochs,
            # then train the discriminator again.
            for _ in range(5):
                pre_train_epoch(sess, leakgan, gen_data_loader)
            train_discriminator(5, log)

    print('end time : ')
    print(datetime.now())
예제 #3
0
def main():
    """Pre-train (or restore) LeakGAN and run adversarial training.

    Behaviour is selected by the module-level ``FLAGS``:

    * ``restore`` with a checkpoint present in ``model_path``: restore
      ``model_path/FLAGS.model`` when ``FLAGS.model`` is given, otherwise the
      latest checkpoint found in ``model_path``.
    * ``resD`` (requires ``model``): restore ``model_path/FLAGS.model``, run
      ``PRE_EPOCH_NUM`` generator MLE epochs, and save ``leakgan_pre``.
    * otherwise: 16 rounds of (5 discriminator rounds, save ``leakgan_preD``,
      ``PRE_EPOCH_NUM // 16`` generator MLE epochs), then save
      ``leakgan_pre``.

    ``TOTAL_BATCH`` adversarial steps follow. Progress is written to
    ``save/experiment-log.txt``.
    """
    import sys

    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    vocab_size = 4839
    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    discriminator = Discriminator(SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  dis_emb_dim=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  batch_size=BATCH_SIZE,
                                  hidden_dim=HIDDEN_DIM,
                                  start_token=START_TOKEN,
                                  goal_out_size=GOAL_OUT_SIZE,
                                  step_size=4)
    leakgan = LeakGAN(SEQ_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      emb_dim=EMB_DIM,
                      dis_emb_dim=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      batch_size=BATCH_SIZE,
                      hidden_dim=HIDDEN_DIM,
                      start_token=START_TOKEN,
                      goal_out_size=GOAL_OUT_SIZE,
                      goal_size=GOAL_SIZE,
                      step_size=4,
                      D_model=discriminator)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Smoke test: draw one batch from the freshly initialised generator.
    g = sess.run(leakgan.gen_x,
                 feed_dict={leakgan.drop_out: 0.8, leakgan.train: 1})
    print(g)
    print('epoch: 0   ')

    def train_discriminator(num_rounds, log_file=None):
        """Run ``num_rounds`` rounds of discriminator training.

        Each round samples fresh negatives into ``negative_file``, reloads the
        discriminator data, trains 3 epochs over every batch, and then calls
        ``leakgan.update_feature_function(discriminator)``. When ``log_file``
        is given, every D loss is appended to it.
        """
        for _round in range(num_rounds):
            generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                             negative_file, 0)
            dis_data_loader.load_train_data(positive_file, negative_file)
            for _epoch in range(3):
                dis_data_loader.reset_pointer()
                for _batch in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.D_input_x: x_batch,
                        discriminator.D_input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                    }
                    D_loss, _ = sess.run(
                        [discriminator.D_loss, discriminator.D_train_op], feed)
                    if log_file is not None:
                        log_file.write(str(D_loss) + '\n')
            leakgan.update_feature_function(discriminator)

    # The context manager closes the log even if training raises.
    with open('save/experiment-log.txt', 'w') as log:
        generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                         negative_file, 0)
        gen_data_loader.create_batches(positive_file)
        saver = tf.train.Saver(tf.global_variables())
        model = tf.train.latest_checkpoint(model_path)
        print(model)

        if FLAGS.restore and model:
            # Test FLAGS.model itself: the old "model_path + '/' + FLAGS.model"
            # test was always truthy (or raised TypeError for None), which
            # made the latest-checkpoint fallback unreachable.
            if FLAGS.model:
                checkpoint = model_path + '/' + FLAGS.model
                print(checkpoint)
                saver.restore(sess, checkpoint)
            else:
                saver.restore(sess, model)
        elif FLAGS.resD:
            if not FLAGS.model:
                print('input all arguments, "resD" and "model"')
                sys.exit(1)
            checkpoint = model_path + '/' + FLAGS.model
            print(checkpoint)
            saver.restore(sess, checkpoint)

            print('Start pre-training...')
            log.write('pre-training...\n')
            for epoch in range(PRE_EPOCH_NUM):
                loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                if epoch % 5 == 0:
                    generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                                     negative_file)
                log.write('epoch:\t' + str(epoch) + '\tnll:\t' + str(loss) + '\n')
            saver.save(sess, model_path + '/leakgan_pre')
        else:
            print('Start pre-training discriminator...')
            # 16 rounds, each: 5 discriminator rounds (fresh negatives plus
            # 3 epochs each), a leakgan_preD checkpoint, then
            # PRE_EPOCH_NUM // 16 generator MLE epochs.
            for _ in range(16):
                train_discriminator(5, log)
                saver.save(sess, model_path + '/leakgan_preD')

                print('Start pre-training...')
                log.write('pre-training...\n')
                # Floor division keeps the epoch count an int on Python 3 too.
                for epoch in range(PRE_EPOCH_NUM // 16):
                    loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                    if epoch % 5 == 0:
                        generate_samples(sess, leakgan, BATCH_SIZE,
                                         generated_num, negative_file, 0)
                    print('pre-train epoch  %s test_loss  %s' % (epoch, loss))
                    log.write('epoch:\t' + str(epoch) + '\tnll:\t' + str(loss) + '\n')
            saver.save(sess, model_path + '/leakgan_pre')

        gencircle = 1
        print('#########################################################################')
        print('Start Adversarial Training...')
        log.write('adversarial training...\n')
        for total_batch in range(TOTAL_BATCH):
            # Train the generator with discriminator-derived rewards.
            for _ in range(gencircle):
                samples = leakgan.generate(sess, 1.0, 1)
                rewards = get_reward(leakgan, discriminator, sess, samples, 4,
                                     dis_dropout_keep_prob, total_batch,
                                     gen_data_loader)
                feed = {leakgan.x: samples,
                        leakgan.reward: rewards,
                        leakgan.drop_out: 1.0}
                _, _, g_loss, w_loss = sess.run(
                    [leakgan.manager_updates, leakgan.worker_updates,
                     leakgan.goal_loss, leakgan.worker_loss],
                    feed_dict=feed)
                print('total_batch:  %s    %s    %s' % (total_batch, g_loss, w_loss))

            # Test: dump samples and checkpoint periodically and at the end.
            if total_batch % 10 == 1 or total_batch == TOTAL_BATCH - 1:
                generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                                 "./save/coco_" + str(total_batch) + ".txt", 0)
                saver.save(sess, model_path + '/leakgan', global_step=total_batch)
            # Occasional MLE epoch to keep the generator close to the data.
            if total_batch % 15 == 0:
                pre_train_epoch(sess, leakgan, gen_data_loader)
            # Train the discriminator (losses are not logged here).
            train_discriminator(5)
예제 #4
0
def main(FLAGS):
    """Train LeakGAN on text8 characters (or, if enabled, on a synthetic oracle).

    Hyper-parameters are read from ``FLAGS``: ``gen_emb_dim``,
    ``gen_hidden_dim``, ``seq_len``, ``pretrain_epoch_num``, ``batch_size``,
    ``dis_emb_dim``, ``experiment_name``, ``num_epochs``, ``save_each_epochs``,
    ``restore``, ``model`` and, for the oracle setup, ``length``.

    Steps:
      1. Build the data loaders and the Discriminator/LeakGAN graphs.
      2. If ``FLAGS.restore`` is set and ``./ckpts`` holds a checkpoint,
         restore ``./ckpts/FLAGS.model`` (or the latest checkpoint when no
         model is named). Otherwise run 10 rounds of interleaved
         discriminator and generator pre-training.
      3. Run ``FLAGS.num_epochs`` adversarial epochs, checkpointing every
         ``FLAGS.save_each_epochs`` epochs under ``./ckp/``. A final
         checkpoint is saved at the end.

    Only ``FLAGS.seq_len`` of 20 or 40 is supported; any other value exits
    with a non-zero status.
    """
    import sys

    #########################################################################################
    #  Generator hyper-parameters
    #########################################################################################
    EMB_DIM = FLAGS.gen_emb_dim  # 32  # embedding dimension
    HIDDEN_DIM = FLAGS.gen_hidden_dim  # 32  # hidden state dimension of lstm cell
    SEQ_LENGTH = FLAGS.seq_len  # 20  # sequence length
    START_TOKEN = 0
    PRE_EPOCH_NUM = FLAGS.pretrain_epoch_num  # 80 # MLE pre-training epochs
    SEED = 88
    BATCH_SIZE = FLAGS.batch_size  # 64
    GOAL_SIZE = 16
    STEP_SIZE = 4

    #########################################################################################
    #  Discriminator hyper-parameters
    #########################################################################################
    dis_embedding_dim = FLAGS.dis_emb_dim  # 64
    if SEQ_LENGTH == 20:
        dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
        dis_num_filters = [
            100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160
        ]
        LEARNING_RATE = 0.0015
    elif SEQ_LENGTH == 40:
        dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 30, 40]
        # One filter count per filter size (14 entries each).
        dis_num_filters = [
            100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160, 160,
            160
        ]
        LEARNING_RATE = 0.0005
    else:
        # sys.exit with a message exits with status 1 and prints the reason,
        # instead of silently reporting success.
        sys.exit('unsupported seq_len %r (expected 20 or 40)' % SEQ_LENGTH)
    print(SEQ_LENGTH)

    GOAL_OUT_SIZE = sum(dis_num_filters)
    dis_dropout_keep_prob = 1.0

    #########################################################################################
    #  Basic training parameters
    #########################################################################################
    EXPERIMENT_NAME = FLAGS.experiment_name
    TOTAL_BATCH = FLAGS.num_epochs  # 800 # number of adversarial epochs
    positive_file = 'save/real_data_%0s.txt' % EXPERIMENT_NAME
    negative_file = 'save/generator_sample_%0s.txt' % EXPERIMENT_NAME
    eval_file = "save/eval_file_%0s" % EXPERIMENT_NAME
    generated_num = 10000
    model_path = './ckpts'

    #########################################################################################
    #  Data configuration
    #########################################################################################
    use_real_world_data = True
    real_data_file_path = './data/text8'
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    if use_real_world_data:
        vocab_size = 27
        real_data_train_file = real_data_file_path + '-train'
        real_data_dict_file = real_data_file_path + '-dict.json'
        if not os.path.exists(real_data_train_file):
            split_text8(real_data_file_path)
        charmap, inv_charmap = create_real_data_dict(real_data_train_file,
                                                     real_data_dict_file)
        gen_data_loader = Gen_Data_loader_text8(BATCH_SIZE,
                                                charmap,
                                                inv_charmap,
                                                seq_len=SEQ_LENGTH)
        dis_data_loader = Dis_dataloader_text8(BATCH_SIZE,
                                               charmap,
                                               inv_charmap,
                                               seq_len=SEQ_LENGTH)
    else:
        gen_data_loader = Gen_Data_loader(BATCH_SIZE, FLAGS.length)
        likelihood_data_loader = Gen_Data_loader(BATCH_SIZE,
                                                 FLAGS.length)  # For testing
        vocab_size = 5000
        with open('save/target_params.pkl', 'rb') as params_file:
            target_params = pickle.load(params_file)
        dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)

    discriminator = Discriminator(SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  dis_emb_dim=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  batch_size=BATCH_SIZE,
                                  hidden_dim=HIDDEN_DIM,
                                  start_token=START_TOKEN,
                                  goal_out_size=GOAL_OUT_SIZE,
                                  step_size=STEP_SIZE)
    leakgan = LeakGAN(SEQ_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      emb_dim=EMB_DIM,
                      dis_emb_dim=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      batch_size=BATCH_SIZE,
                      hidden_dim=HIDDEN_DIM,
                      start_token=START_TOKEN,
                      goal_out_size=GOAL_OUT_SIZE,
                      goal_size=GOAL_SIZE,
                      step_size=STEP_SIZE,
                      D_model=discriminator,
                      learning_rate=LEARNING_RATE)

    if not use_real_world_data:
        if SEQ_LENGTH == 40:
            target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM,
                                      HIDDEN_DIM, SEQ_LENGTH,
                                      START_TOKEN)  # The oracle model
        else:
            target_lstm = TARGET_LSTM20(vocab_size, BATCH_SIZE, EMB_DIM,
                                        HIDDEN_DIM, SEQ_LENGTH, START_TOKEN,
                                        target_params)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    if use_real_world_data:
        gen_data_loader.create_batches(real_data_train_file,
                                       limit_num_samples=generated_num)
    else:
        # Use the oracle model to provide the positive examples.
        generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                         positive_file, 0)
        gen_data_loader.create_batches(positive_file)

    # Smoke test: draw one batch from the freshly initialised generator.
    g = sess.run(leakgan.gen_x,
                 feed_dict={
                     leakgan.drop_out: 0.8,
                     leakgan.train: 1
                 })
    print(g)
    print("epoch:", 0, "  ")

    # A single saver over all global variables. A large max_to_keep stops
    # tf.train.Saver from deleting older checkpoints (its default keeps only
    # the 5 most recent), which would defeat the per-epoch checkpoint
    # directories written below.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=999999)

    def save_checkpoint(save_file):
        """Save all global variables to ``save_file``, creating its directory."""
        save_dir = os.path.dirname(save_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        saver.save(sess, save_file)

    def oracle_nll(model):
        """Sample ``model`` into ``eval_file`` and return the oracle NLL.

        Only valid in the synthetic (``not use_real_world_data``) setup, where
        ``target_lstm`` and ``likelihood_data_loader`` exist.
        """
        generate_samples(sess, model, BATCH_SIZE, generated_num, eval_file, 0)
        likelihood_data_loader.create_batches(eval_file)
        return target_loss(sess, target_lstm, likelihood_data_loader)

    def train_discriminator(num_rounds):
        """Run ``num_rounds`` rounds of discriminator training.

        Each round samples fresh negatives, reloads the discriminator data,
        trains 3 epochs over every batch, and then calls
        ``leakgan.update_feature_function(discriminator)``.
        """
        for _round in range(num_rounds):
            if use_real_world_data:
                generate_real_data_samples(sess, leakgan, BATCH_SIZE,
                                           generated_num, negative_file,
                                           inv_charmap)
                dis_data_loader.load_train_data(real_data_train_file,
                                                negative_file)
            else:
                generate_samples(sess, leakgan, BATCH_SIZE, generated_num,
                                 negative_file, 0)
                generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                                 positive_file, 0)
                dis_data_loader.load_train_data(positive_file, negative_file)
            for _epoch in range(3):
                dis_data_loader.reset_pointer()
                for _batch in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.D_input_x: x_batch,
                        discriminator.D_input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    sess.run([discriminator.D_loss, discriminator.D_train_op],
                             feed)
            leakgan.update_feature_function(discriminator)

    def pretrain_generator(num_epochs):
        """MLE pre-train the generator for ``num_epochs`` epochs.

        Every 5th epoch, samples are written and an NLL line is logged. For
        real data the NLL is a placeholder 0.
        """
        for epoch in range(num_epochs):
            pre_train_epoch(sess, leakgan, gen_data_loader)
            if epoch % 5 != 0:
                continue
            if use_real_world_data:
                generate_real_data_samples(
                    sess, leakgan, BATCH_SIZE, generated_num,
                    eval_file + "_epoch_%0d.txt" % epoch, inv_charmap)
                test_loss = 0  # FIXME - TEMP: no NLL metric for real data yet
            else:
                test_loss = oracle_nll(leakgan)
            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            log.write('epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) +
                      '\n')
            if use_real_world_data:
                test_loss = 0  # FIXME - TEMP
            else:
                test_loss = oracle_nll(target_lstm)
            print("Groud-Truth:", test_loss)

    # The context manager closes the log even if training raises.
    with open('save/experiment-log.txt', 'w') as log:
        model = tf.train.latest_checkpoint(model_path)
        print(model)
        if FLAGS.restore and model:
            # Test FLAGS.model itself: concatenating it onto model_path was
            # always truthy, which made the latest-checkpoint fallback
            # unreachable.
            if FLAGS.model:
                print(model_path + '/' + FLAGS.model)
                saver.restore(sess, model_path + '/' + FLAGS.model)
            else:
                saver.restore(sess, model)
        else:
            # Resuming from a discriminator checkpoint (FLAGS.resD) is not
            # supported here; always pre-train from scratch.
            print('Start pre-training discriminator...')
            for _ in range(10):
                train_discriminator(5)
                save_checkpoint(model_path + '/leakgan_preD')

                print('Start pre-training...')
                log.write('pre-training...\n')
                pretrain_generator(PRE_EPOCH_NUM // 10)
            save_checkpoint(model_path + '/leakgan_pre')

        gencircle = 1
        print(
            '#########################################################################'
        )
        print('Start Adversarial Training...')
        log.write('adversarial training...\n')
        for total_batch in range(TOTAL_BATCH):
            print("start epoch %0d" % total_batch)

            # A save_each_epochs of 0 disables periodic checkpoints instead
            # of raising ZeroDivisionError.
            if FLAGS.save_each_epochs and total_batch % FLAGS.save_each_epochs == 0:
                print(
                    '#########################################################################'
                )
                print('saving model...')
                epoch_name = EXPERIMENT_NAME + '_epoch_%0d' % total_batch
                save_checkpoint(os.path.join('.', 'ckp', epoch_name, epoch_name))

            # Train the generator with discriminator-derived rewards.
            for _ in range(gencircle):
                samples = leakgan.generate(sess, 1.0, 1)
                rewards = get_reward(leakgan, discriminator, sess, samples, 4,
                                     dis_dropout_keep_prob)
                feed = {
                    leakgan.x: samples,
                    leakgan.reward: rewards,
                    leakgan.drop_out: 1.0
                }
                _, _, g_loss, w_loss = sess.run(
                    [
                        leakgan.manager_updates, leakgan.worker_updates,
                        leakgan.goal_loss, leakgan.worker_loss
                    ],
                    feed_dict=feed)
                print('total_batch: ', total_batch, "  ", g_loss, "  ", w_loss)

            # Test (oracle NLL is only available for synthetic data).
            if not use_real_world_data and (total_batch % 5 == 0 or
                                            total_batch == TOTAL_BATCH - 1):
                test_loss = oracle_nll(leakgan)
                buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                    test_loss) + '\n'
                print('total_batch: ', total_batch, 'test_loss: ', test_loss)
                log.write(buffer)
                print("Groud-Truth:", oracle_nll(target_lstm))

            # Train the discriminator.
            train_discriminator(5)

        print(
            '#########################################################################'
        )
        print('saving model...')
        save_checkpoint(os.path.join('.', 'ckp', EXPERIMENT_NAME, EXPERIMENT_NAME))
예제 #5
0
def main():
    """Train LeakGAN on synthetic data sampled from an oracle LSTM.

    Builds the discriminator, the LeakGAN generator and the oracle
    (``TARGET_LSTM`` when ``SEQ_LENGTH == 40``, otherwise ``TARGET_LSTM20``
    with the parameters stored in ``save/target_params.pkl``) and writes
    oracle samples to ``positive_file``. Then:

    * ``FLAGS.restore`` and a checkpoint exists in ``model_path``: restore
      ``FLAGS.model`` if given, otherwise the latest checkpoint;
    * ``FLAGS.resD`` with a non-empty ``FLAGS.model``: restore that model and
      run ``PRE_EPOCH_NUM`` generator pre-training epochs;
    * otherwise: 10 rounds of discriminator pre-training, each followed by
      ``PRE_EPOCH_NUM // 10`` generator pre-training epochs.

    Finally runs ``TOTAL_BATCH`` rounds of adversarial training. Progress is
    written to ``save/experiment-log.txt``.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, FLAGS.length)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, FLAGS.length)  # For testing
    vocab_size = 5000
    # Close the parameter file as soon as it is read. This is a local,
    # trusted pickle; never point cPickle.load at untrusted data.
    with open('save/target_params.pkl', 'rb') as params_file:
        target_params = cPickle.load(params_file)

    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    discriminator = Discriminator(SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  dis_emb_dim=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  batch_size=BATCH_SIZE,
                                  hidden_dim=HIDDEN_DIM,
                                  start_token=START_TOKEN,
                                  goal_out_size=GOAL_OUT_SIZE,
                                  step_size=4)
    leakgan = LeakGAN(SEQ_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      emb_dim=EMB_DIM,
                      dis_emb_dim=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      batch_size=BATCH_SIZE,
                      hidden_dim=HIDDEN_DIM,
                      start_token=START_TOKEN,
                      goal_out_size=GOAL_OUT_SIZE,
                      goal_size=GOAL_SIZE,
                      step_size=4,
                      D_model=discriminator,
                      learning_rate=LEARNING_RATE)
    if SEQ_LENGTH == 40:
        target_lstm = TARGET_LSTM(vocab_size,
                                  BATCH_SIZE,
                                  EMB_DIM,
                                  HIDDEN_DIM,
                                  SEQ_LENGTH,
                                  START_TOKEN)  # The oracle model
    else:
        target_lstm = TARGET_LSTM20(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file, 0)
    # Smoke test: sample once from the freshly initialised generator.
    for a in range(1):
        g = sess.run(leakgan.gen_x, feed_dict={leakgan.drop_out: 0.8, leakgan.train: 1})
        print(g)
        print("epoch:", a, "  ")

    # ``with`` guarantees the log is flushed and closed even if training raises.
    with open('save/experiment-log.txt', 'w') as log:
        gen_data_loader.create_batches(positive_file)
        saver = tf.train.Saver(tf.global_variables())
        model = tf.train.latest_checkpoint(model_path)
        print(model)
        if FLAGS.restore and model:
            # Test FLAGS.model itself: ``model_path + '/' + FLAGS.model`` is
            # never empty, which made the latest-checkpoint fallback unreachable.
            if FLAGS.model:
                named_model = model_path + '/' + FLAGS.model
                print(named_model)
                saver.restore(sess, named_model)
            else:
                saver.restore(sess, model)
        elif FLAGS.resD and FLAGS.model:
            named_model = model_path + '/' + FLAGS.model
            print(named_model)
            saver.restore(sess, named_model)

            print('Start pre-training...')
            log.write('pre-training...\n')
            for epoch in range(PRE_EPOCH_NUM):
                pre_train_epoch(sess, leakgan, gen_data_loader)
                if epoch % 5 == 0:
                    generate_samples(sess, leakgan, BATCH_SIZE, generated_num, eval_file, 0)
                    likelihood_data_loader.create_batches(eval_file)
                    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                    print('pre-train epoch ', epoch, 'test_loss ', test_loss)
                    log.write('epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n')
                    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file, 0)
                    likelihood_data_loader.create_batches(eval_file)
                    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                    print("Groud-Truth:", test_loss)
            saver.save(sess, model_path + '/leakgan_pre')
        else:
            print('Start pre-training discriminator...')
            # Last discriminator loss; stays None until the first batch runs,
            # so the print below cannot hit an unbound name.
            D_loss = None
            for i in range(10):
                for _ in range(5):
                    generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file, 0)
                    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file, 0)
                    dis_data_loader.load_train_data(positive_file, negative_file)
                    for _ in range(3):
                        dis_data_loader.reset_pointer()
                        for it in range(dis_data_loader.num_batch):
                            x_batch, y_batch = dis_data_loader.next_batch()
                            feed = {
                                discriminator.D_input_x: x_batch,
                                discriminator.D_input_y: y_batch,
                                discriminator.dropout_keep_prob: dis_dropout_keep_prob
                            }
                            D_loss, _ = sess.run([discriminator.D_loss, discriminator.D_train_op], feed)
                    print("D_loss: ", D_loss)
                    leakgan.update_feature_function(discriminator)  # todo: is important
                saver.save(sess, model_path + '/leakgan_preD')

                print('Start pre-training generator...')
                log.write('pre-training...\n')
                # Floor division keeps the epoch count an int on Python 3 too.
                for epoch in range(PRE_EPOCH_NUM // 10):
                    loss = pre_train_epoch(sess, leakgan, gen_data_loader)
                    if epoch % 5 == 0:
                        print("MLE Generator Loss: ", loss)
            saver.save(sess, model_path + '/leakgan_pre')

        gencircle = 1
        print('#########################################################################')
        print('Start Adversarial Training...')
        log.write('adversarial training...\n')
        for total_batch in range(TOTAL_BATCH):
            # Generator update(s) driven by rewards computed with the discriminator.
            for gi in range(gencircle):
                samples = leakgan.generate(sess, 1.0, 1)
                rewards = get_reward(leakgan,
                                     discriminator,
                                     sess,
                                     samples,
                                     4,
                                     dis_dropout_keep_prob)
                feed = {leakgan.x: samples,
                        leakgan.reward: rewards,
                        leakgan.drop_out: 0.5}
                _, _, g_loss, w_loss = sess.run(
                    [leakgan.manager_updates, leakgan.worker_updates, leakgan.goal_loss, leakgan.worker_loss],
                    feed_dict=feed)
                print('total_batch: ', total_batch, "  ", g_loss, "  ", w_loss)

            # Oracle-NLL evaluation during adversarial training is disabled here.

            # Train the discriminator on fresh generator / oracle samples.
            for _ in range(5):
                generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file, 0)
                generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file, 0)
                dis_data_loader.load_train_data(positive_file, negative_file)

                for _ in range(3):
                    dis_data_loader.reset_pointer()
                    for it in range(dis_data_loader.num_batch):
                        x_batch, y_batch = dis_data_loader.next_batch()
                        feed = {
                            discriminator.D_input_x: x_batch,
                            discriminator.D_input_y: y_batch,
                            discriminator.dropout_keep_prob: dis_dropout_keep_prob
                        }
                        D_loss, _ = sess.run([discriminator.D_loss, discriminator.D_train_op], feed)
                leakgan.update_feature_function(discriminator)
예제 #6
0
파일: Main.py 프로젝트: NemoNone/LeakGAN
def main():
    """Train LeakGAN on synthetic data sampled from an oracle LSTM.

    Builds the discriminator, the LeakGAN generator and the oracle
    (``TARGET_LSTM`` when ``SEQ_LENGTH == 40``, otherwise ``TARGET_LSTM20``
    with the parameters stored in ``save/target_params.pkl``) and writes
    oracle samples to ``positive_file``. Then:

    * ``FLAGS.restore`` and a checkpoint exists in ``model_path``: restore
      ``FLAGS.model`` if given, otherwise the latest checkpoint;
    * ``FLAGS.resD`` with a non-empty ``FLAGS.model``: restore that model and
      run ``PRE_EPOCH_NUM`` generator pre-training epochs;
    * otherwise: 10 rounds of discriminator pre-training, each followed by
      ``PRE_EPOCH_NUM // 10`` generator pre-training epochs.

    Finally runs ``TOTAL_BATCH`` rounds of adversarial training, logging the
    oracle NLL every 5 rounds and on the last one to
    ``save/experiment-log.txt``.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, FLAGS.length)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, FLAGS.length)  # For testing
    vocab_size = 5000
    # Close the parameter file as soon as it is read. This is a local,
    # trusted pickle; never point cPickle.load at untrusted data.
    with open('save/target_params.pkl', 'rb') as params_file:
        target_params = cPickle.load(params_file)

    dis_data_loader = Dis_dataloader(BATCH_SIZE, SEQ_LENGTH)
    discriminator = Discriminator(SEQ_LENGTH, num_classes=2, vocab_size=vocab_size,
                                  dis_emb_dim=dis_embedding_dim, filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters, batch_size=BATCH_SIZE,
                                  hidden_dim=HIDDEN_DIM, start_token=START_TOKEN,
                                  goal_out_size=GOAL_OUT_SIZE, step_size=4)
    leakgan = LeakGAN(SEQ_LENGTH, num_classes=2, vocab_size=vocab_size, emb_dim=EMB_DIM,
                      dis_emb_dim=dis_embedding_dim, filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters, batch_size=BATCH_SIZE,
                      hidden_dim=HIDDEN_DIM, start_token=START_TOKEN,
                      goal_out_size=GOAL_OUT_SIZE, goal_size=GOAL_SIZE, step_size=4,
                      D_model=discriminator, learning_rate=LEARNING_RATE)
    if SEQ_LENGTH == 40:
        target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)  # The oracle model
    else:
        target_lstm = TARGET_LSTM20(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file, 0)
    # Smoke test: sample once from the freshly initialised generator.
    for a in range(1):
        g = sess.run(leakgan.gen_x, feed_dict={leakgan.drop_out: 0.8, leakgan.train: 1})
        print(g)

        print("epoch:", a, "  ")

    # ``with`` guarantees the log is flushed and closed even if training raises.
    with open('save/experiment-log.txt', 'w') as log:
        gen_data_loader.create_batches(positive_file)
        saver = tf.train.Saver(tf.global_variables())
        model = tf.train.latest_checkpoint(model_path)
        print(model)
        if FLAGS.restore and model:
            # Test FLAGS.model itself: ``model_path + '/' + FLAGS.model`` is
            # never empty, which made the latest-checkpoint fallback unreachable.
            if FLAGS.model:
                named_model = model_path + '/' + FLAGS.model
                print(named_model)
                saver.restore(sess, named_model)
            else:
                saver.restore(sess, model)
        elif FLAGS.resD and FLAGS.model:
            named_model = model_path + '/' + FLAGS.model
            print(named_model)
            saver.restore(sess, named_model)

            print('Start pre-training...')
            log.write('pre-training...\n')
            for epoch in range(PRE_EPOCH_NUM):
                pre_train_epoch(sess, leakgan, gen_data_loader)
                if epoch % 5 == 0:
                    generate_samples(sess, leakgan, BATCH_SIZE, generated_num, eval_file, 0)
                    likelihood_data_loader.create_batches(eval_file)
                    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                    print('pre-train epoch ', epoch, 'test_loss ', test_loss)
                    log.write('epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n')
                    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file, 0)
                    likelihood_data_loader.create_batches(eval_file)
                    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                    print("Groud-Truth:", test_loss)
            saver.save(sess, model_path + '/leakgan_pre')
        else:
            print('Start pre-training discriminator...')
            # 10 x 5 = 50 rounds; each regenerates data and trains 3 epochs on it.
            for i in range(10):
                for _ in range(5):
                    generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file, 0)
                    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file, 0)
                    dis_data_loader.load_train_data(positive_file, negative_file)
                    for _ in range(3):
                        dis_data_loader.reset_pointer()
                        for it in range(dis_data_loader.num_batch):
                            x_batch, y_batch = dis_data_loader.next_batch()
                            feed = {
                                discriminator.D_input_x: x_batch,
                                discriminator.D_input_y: y_batch,
                                discriminator.dropout_keep_prob: dis_dropout_keep_prob
                            }
                            D_loss, _ = sess.run([discriminator.D_loss, discriminator.D_train_op], feed)
                    leakgan.update_feature_function(discriminator)
                saver.save(sess, model_path + '/leakgan_preD')

                # Pre-train the generator between discriminator rounds.
                print('Start pre-training...')
                log.write('pre-training...\n')
                # Floor division keeps the epoch count an int on Python 3 too.
                for epoch in range(PRE_EPOCH_NUM // 10):
                    pre_train_epoch(sess, leakgan, gen_data_loader)
                    if epoch % 5 == 0:
                        generate_samples(sess, leakgan, BATCH_SIZE, generated_num, eval_file, 0)
                        likelihood_data_loader.create_batches(eval_file)
                        test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                        print('pre-train epoch ', epoch, 'test_loss ', test_loss)
                        log.write('epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n')
                        generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file, 0)
                        likelihood_data_loader.create_batches(eval_file)
                        test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                        print("Groud-Truth:", test_loss)
            saver.save(sess, model_path + '/leakgan_pre')

        gencircle = 1
        print('#########################################################################')
        print('Start Adversarial Training...')
        log.write('adversarial training...\n')
        for total_batch in range(TOTAL_BATCH):
            # Generator update(s) driven by rewards computed with the discriminator.
            for gi in range(gencircle):
                samples = leakgan.generate(sess, 1.0, 1)
                rewards = get_reward(leakgan, discriminator, sess, samples, 4, dis_dropout_keep_prob)
                feed = {leakgan.x: samples, leakgan.reward: rewards, leakgan.drop_out: 1.0}
                _, _, g_loss, w_loss = sess.run(
                    [leakgan.manager_updates, leakgan.worker_updates, leakgan.goal_loss, leakgan.worker_loss],
                    feed_dict=feed)
                print('total_batch: ', total_batch, "  ", g_loss, "  ", w_loss)

            # Oracle NLL of generator samples every 5 rounds and on the last one.
            if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
                generate_samples(sess, leakgan, BATCH_SIZE, generated_num, eval_file, 0)
                likelihood_data_loader.create_batches(eval_file)
                test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                print('total_batch: ', total_batch, 'test_loss: ', test_loss)
                log.write('epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n')
                generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file, 0)
                likelihood_data_loader.create_batches(eval_file)
                test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
                print("Groud-Truth:", test_loss)

            # Train the discriminator on fresh generator / oracle samples.
            for _ in range(5):
                generate_samples(sess, leakgan, BATCH_SIZE, generated_num, negative_file, 0)
                generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file, 0)
                dis_data_loader.load_train_data(positive_file, negative_file)

                for _ in range(3):
                    dis_data_loader.reset_pointer()
                    for it in range(dis_data_loader.num_batch):
                        x_batch, y_batch = dis_data_loader.next_batch()
                        feed = {
                            discriminator.D_input_x: x_batch,
                            discriminator.D_input_y: y_batch,
                            discriminator.dropout_keep_prob: dis_dropout_keep_prob
                        }
                        D_loss, _ = sess.run([discriminator.D_loss, discriminator.D_train_op], feed)
                leakgan.update_feature_function(discriminator)
def _read_experiment_config(exp_name, epoch_exp):
    """Read the saved model sizes for one experiment directory.

    Args:
        exp_name: name of the experiment directory under ``ckp``.
        epoch_exp: when true, ``exp_name`` ends in ``_epoch_<n>``; the suffix
            is stripped to locate the shared config file.

    Returns:
        ``(gen_emb_dim, gen_hidden_dim, dis_emb_dim, seq_len)`` as returned by
        ``restore_param_from_config``, or ``None`` if reading any of them
        raised.
    """
    if epoch_exp:
        config = os.path.join('ckp', 'config_' + exp_name.split('_epoch_')[0] + '.txt')
    else:
        config = os.path.join('ckp', 'config_' + exp_name + '.txt')
    try:
        return (restore_param_from_config(config, param='gen_emb_dim'),
                restore_param_from_config(config, param='gen_hidden_dim'),
                restore_param_from_config(config, param='dis_emb_dim'),
                restore_param_from_config(config, param='seq_len'))
    except Exception:
        # Narrower than a bare ``except:`` so Ctrl-C / SystemExit still
        # propagate. The exact exceptions restore_param_from_config raises
        # are not visible here, hence the broad class.
        return None


def _restore_matching_variables(sess, save_file):
    """Restore graph variables that exist in ``save_file`` with the same shape.

    Variables absent from the checkpoint are skipped without a message;
    variables stored with a different shape are reported and skipped. No
    initializer is run for the skipped variables.

    Args:
        sess: session to restore into.
        save_file: checkpoint prefix understood by ``tf.train.NewCheckpointReader``.
    """
    saved_shapes = tf.train.NewCheckpointReader(save_file).get_variable_to_shape_map()
    global_vars = tf.global_variables()
    name2var = {var.name.split(':')[0]: var for var in global_vars}
    # Sorted by full tensor name (with the ':0' suffix) to keep a stable order.
    var_names = sorted((var.name, var.name.split(':')[0])
                       for var in global_vars
                       if var.name.split(':')[0] in saved_shapes)
    restore_vars = []
    for _, saved_var_name in var_names:
        curr_var = name2var[saved_var_name]
        if curr_var.get_shape().as_list() == saved_shapes[saved_var_name]:
            restore_vars.append(curr_var)
        else:
            print("Not loading: %s." % saved_var_name)
    saver = tf.train.Saver(restore_vars)
    print("Restoring vars:")
    print(restore_vars)
    saver.restore(sess, save_file)


def _append_result(line):
    """Append ``line`` to ``RESULTS_PATH`` and echo it to stdout."""
    with open(RESULTS_PATH, 'a') as f:
        f.write(line + '\n')
    print(line)


def main(FLAGS):
    """Evaluate every LeakGAN checkpoint directory under ``ckp`` on text8.

    For each experiment directory, the model sizes are read from its config
    file, the graph is rebuilt, weights with matching names and shapes are
    restored, and ``language_model_evaluation`` is run on the text8 valid
    split (test split with ``FLAGS.test``). ``BPC_direct`` (and
    ``BPC_approx`` with ``FLAGS.test``) are appended to ``RESULTS_PATH``.
    With ``FLAGS.epoch_exp`` the per-epoch ``BPC_direct`` values are also
    saved via ``np.save('direct_results_epoch_exp', ...)``.

    Args:
        FLAGS: parsed flags; reads ``epoch_exp``, ``dump_samples`` and ``test``.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    if use_real_world_data:
        vocab_size = 27
        # split to train-valid-test
        real_data_train_file = real_data_file_path + '-train'
        real_data_valid_file = real_data_file_path + '-valid'
        real_data_test_file = real_data_file_path + '-test'
        real_data_dict_file = real_data_file_path + '-dict.json'
        if not os.path.exists(real_data_train_file):
            split_text8(real_data_file_path)
        charmap, inv_charmap = create_real_data_dict(real_data_train_file,
                                                     real_data_dict_file)
    else:
        # NOTE(review): the evaluation loop below uses charmap, inv_charmap
        # and the real_data_* paths, which only the branch above defines, so
        # evaluating any model on this path raises NameError.
        gen_data_loader = Gen_Data_loader(BATCH_SIZE)
        likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
        vocab_size = 5000
        dis_data_loader = Dis_dataloader(BATCH_SIZE)

    experiments_list = [
        exp for exp in os.listdir('ckp')
        if os.path.isdir(os.path.join('ckp', exp))
    ]
    if FLAGS.epoch_exp:
        experiments_list.sort(key=lambda x: int(x.split('_epoch_')[-1]))
        stats = np.zeros([2, len(experiments_list)], dtype=np.float32)

    for i, exp_name in enumerate(experiments_list):
        print(
            '#########################################################################'
        )
        print('loading model [%0s]...' % exp_name)

        # restore generator arch
        params = _read_experiment_config(exp_name, FLAGS.epoch_exp)
        if params is None:
            print("ERROR: CONFIG FILE WAS NOT FOUND - skipping model")
            continue
        emb_dim, hidden_dim, dis_embedding_dim, seq_len = params
        assert type(emb_dim) == int
        assert type(hidden_dim) == int
        assert type(dis_embedding_dim) == int
        assert type(seq_len) == int

        if seq_len == 20:
            dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
            dis_num_filters = [
                100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160
            ]
            learning_rate = 0.0015
        elif seq_len == 40:
            # NOTE(review): 14 filter sizes but only 13 filter counts; confirm
            # how Discriminator pairs them before changing either list (the
            # saved checkpoints were built from these exact values).
            dis_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 30, 40]
            dis_num_filters = [
                100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160, 160
            ]
            learning_rate = 0.0005
        else:
            # Skip like a missing config instead of exit(0), which aborted the
            # whole sweep while reporting success.
            print("ERROR: unsupported seq_len %d - skipping model" % seq_len)
            continue
        # Build the graph and test loader at the experiment's own sequence
        # length, the same value the filter sizes above were chosen from.
        print(seq_len)

        goal_out_size = sum(dis_num_filters)
        goal_size = 16

        tf.reset_default_graph()
        discriminator = Discriminator(seq_len,
                                      num_classes=2,
                                      vocab_size=vocab_size,
                                      dis_emb_dim=dis_embedding_dim,
                                      filter_sizes=dis_filter_sizes,
                                      num_filters=dis_num_filters,
                                      batch_size=BATCH_SIZE,
                                      hidden_dim=hidden_dim,
                                      start_token=START_TOKEN,
                                      goal_out_size=goal_out_size,
                                      step_size=4)
        generator = LeakGAN(seq_len,
                            num_classes=2,
                            vocab_size=vocab_size,
                            emb_dim=emb_dim,
                            dis_emb_dim=dis_embedding_dim,
                            filter_sizes=dis_filter_sizes,
                            num_filters=dis_num_filters,
                            batch_size=BATCH_SIZE,
                            hidden_dim=hidden_dim,
                            start_token=START_TOKEN,
                            goal_out_size=goal_out_size,
                            goal_size=goal_size,
                            step_size=4,
                            D_model=discriminator,
                            learning_rate=learning_rate)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # Close each session before the next model resets the default graph,
        # rather than keeping one session (and its GPU memory) per experiment.
        with tf.Session(config=config) as sess:
            # Weights are restored only; no global initializer is run.
            save_file = os.path.join('.', 'ckp', exp_name, exp_name)
            _restore_matching_variables(sess, save_file)

            if FLAGS.dump_samples:
                print('###')
                print('Saving samples file...')
                generate_real_data_samples(sess, generator, BATCH_SIZE, BATCH_SIZE,
                                           "save/lm_eval_file_%0s.txt" % exp_name,
                                           inv_charmap)

            print('###')
            print('Start Language Model Evaluation...')
            test_data_loader = Gen_Data_loader_text8(BATCH_SIZE, charmap,
                                                     inv_charmap, seq_len)
            if FLAGS.test:
                test_data_loader.create_batches(real_data_test_file)
                print("USING TEXT8 TEST SET")
            else:
                test_data_loader.create_batches(real_data_valid_file)
                print("USING TEXT8 VALID SET")
            BPC_direct, BPC_approx = language_model_evaluation(sess,
                                                               generator,
                                                               test_data_loader,
                                                               is_test=FLAGS.test)

        _append_result("[%0s] BPC_direct = %f" % (exp_name, BPC_direct))
        if FLAGS.test:
            _append_result("[%0s] BPC_approx = %f" % (exp_name, BPC_approx))

        if FLAGS.epoch_exp:
            stats[0, i] = int(exp_name.split('_epoch_')[-1])
            stats[1, i] = BPC_direct

    if FLAGS.epoch_exp:
        np.save('direct_results_epoch_exp', stats)