Example No. 1
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    # assert START_TOKEN == 0

    vocab_size = NUM_EMB
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          MAX_LENGTH, START_TOKEN)
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              MAX_LENGTH, 0)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(sequence_length=MAX_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      embedding_size=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [
        param for param in tf.trainable_variables()
        if 'discriminator' in param.name
    ]
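    # Only the discriminator's variables are collected here, so the Adam update
    # below never touches the generator's parameters.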
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss,
                                                         cnn_params,
                                                         aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars,
                                                 global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    def train_discriminator():
        if D_WEIGHT == 0:
            return 0, 0, 0  # match the (loss, accuracy, ypred) tuple returned below

        negative_samples = generate_samples(sess, generator, BATCH_SIZE,
                                            POSITIVE_NUM)

        # Disabled trick: label a small fraction (5% here) of the positive
        # samples as negative to add label noise and avoid collapsing training.
        # global positive_samples
        # pos_new = positive_samples
        # random.shuffle(pos_new)
        # length = len(pos_new)
        # fake_neg_number = int(0.05 * length)
        # fake_neg = pos_new[:fake_neg_number]
        # pos_new = pos_new[fake_neg_number:]
        # negative_samples += fake_neg
        # random.shuffle(negative_samples)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_samples, negative_samples)
        dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train),
                                                 dis_batch_size,
                                                 dis_num_epochs)

        ypred = 0
        counter = 0
        for batch in dis_batches:
            x_batch, y_batch = zip(*batch)
            feed = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: dis_dropout_keep_prob
            }
            _, step, loss, accuracy, ypred_for_auc = sess.run([
                dis_train_op, dis_global_step, cnn.loss, cnn.accuracy,
                cnn.ypred_for_auc
            ], feed)
            ypred_vect = np.array([item[1] for item in ypred_for_auc])
            ypred += np.mean(ypred_vect)
            counter += 1
        ypred = float(ypred) / counter
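        # ypred is the running mean of column 1 of ypred_for_auc (the predicted
        # probability of one of the two classes, presumably 'real'), averaged
        # over all discriminator batches.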
        print('\tD loss  :   {}'.format(loss))
        print('\tAccuracy: {}'.format(accuracy))
        print('\tMean ypred: {}'.format(ypred))
        return loss, accuracy, ypred

    # We check for a previous session; pre-training is checkpointed and only
    # executes if we don't find a checkpoint
    saver = tf.train.Saver()

    #check previous session
    prev_sess = False
    ckpt_dir = 'checkpoints/mingan'

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    # ckpt_file = os.path.join(ckpt_dir, ckpt_dir + '_model')   # old checkpoint
    ckpt_file = os.path.join(
        ckpt_dir, 'drd2_new' + '_model_'
    )  # new checkpoint: iterate over saved checkpoints to find the largest batch index

    nbatches_max = 0
    for i in range(500):  # scan at most 500 batch checkpoints
        if os.path.isfile(ckpt_file + str(i) +
                          '.meta'):  # and params["LOAD_PREV_SESS"]
            nbatches_max = i

    # end: find the highest checkpoint index
    ckpt_file = ckpt_file + str(nbatches_max) + '.meta'
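    # A hedged alternative sketch (assuming the same 'drd2_new_model_<i>.meta'
    # naming), which avoids the fixed 500-iteration scan above:
    #   import glob, re
    #   metas = glob.glob(os.path.join(ckpt_dir, 'drd2_new_model_*.meta'))
    #   idxs = [int(re.search(r'_model_(\d+)\.meta$', m).group(1)) for m in metas]
    #   nbatches_max = max(idxs) if idxs else 0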
    if params["LOAD_PREV_SESS"]:  # and os.path.isfile(ckpt_file):
        #        saver_test = tf.train.import_meta_graph(ckpt_file)
        #        sess.run(tf.global_variables_initializer())
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))

        #        saver.restore(sess, ckpt_file)
        print('Previous session loaded from previous checkpoint {}'.format(
            ckpt_file))
        prev_sess = True
    else:
        if params["LOAD_PREV_SESS"]:
            print('\t* No previous session data found as {:s}.'.format(
                ckpt_file))
        else:
            print('\t* LOAD_PREV_SESS was set to false.')

#        sess.run(tf.global_variables_initializer())
#     pretrain(sess, generator, target_lstm, train_discriminator)
#     path = saver.save(sess, ckpt_file)
#    print('Pretrain finished and saved at {}'.format(path))

    if not prev_sess:
        # check for a pre-training checkpoint
        ckpt_dir = 'checkpoints/{}_pretrain'.format(PREFIX)
        if not os.path.exists(ckpt_dir):
            os.makedirs(ckpt_dir)
        ckpt_file = os.path.join(ckpt_dir, 'pretrain_ckpt')
        if os.path.isfile(ckpt_file + '.meta') and params["LOAD_PRETRAIN"]:
            saver.restore(sess, ckpt_file)
            print('Pretrain loaded from previous checkpoint {}'.format(
                ckpt_file))
        else:
            if params["LOAD_PRETRAIN"]:
                print('\t* No pre-training data found as {:s}.'.format(
                    ckpt_file))
            else:
                print('\t* LOAD_PRETRAIN was set to false.')

            sess.run(tf.global_variables_initializer())
            pretrain(sess, generator, target_lstm, train_discriminator)
            path = saver.save(sess, ckpt_file)
            print('Pretrain finished and saved at {}'.format(path))


    # end loading previous session or pre-training

    # create reward function
    batch_reward = make_reward(train_samples)

    rollout = ROLLOUT(generator, 0.8)
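    # ROLLOUT keeps a delayed copy of the generator for Monte Carlo rollouts;
    # 0.8 is the rate at which update_params() moves that copy toward the
    # current generator weights.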

    #    nbatches_max= 30

    print(
        '#########################################################################'
    )
    print('Start Reinforcement Training Generator...')
    results_rows = []

    if nbatches_max + 1 > TOTAL_BATCH:
        print(
            ' We already trained that many batches: Check the Checkpoints folder or take a larger TOTAL_BATCH'
        )
    else:
        for nbatch in tqdm(range(nbatches_max + 1, TOTAL_BATCH)):

            #for nbatch in tqdm(range(TOTAL_BATCH)):
            results = OrderedDict({'exp_name': PREFIX})
            if nbatch % 1 == 0 or nbatch == TOTAL_BATCH - 1:
                print('* Making samples')
                if nbatch % 10 == 0:
                    gen_samples = generate_samples(sess, generator, BATCH_SIZE,
                                                   BIG_SAMPLE_NUM)
                else:
                    gen_samples = generate_samples(sess, generator, BATCH_SIZE,
                                                   SAMPLE_NUM)
                likelihood_data_loader.create_batches(gen_samples)
                test_loss = target_loss(sess, target_lstm,
                                        likelihood_data_loader)
                print('batch_num: {}'.format(nbatch))
                print('test_loss: {}'.format(test_loss))
                results['Batch'] = nbatch
                results['test_loss'] = test_loss

                if test_loss < best_score:
                    best_score = test_loss
                    print('best score: %f' % test_loss)

                # results
                mm.compute_results(gen_samples, train_samples, ord_dict,
                                   results)

            print(
                '#########################################################################'
            )
            print('-> Training generator with RL.')
            print('G Epoch {}'.format(nbatch))

            for it in range(TRAIN_ITER):
                samples = generator.generate(sess)
                rewards = rollout.get_reward(sess, samples, 16, cnn,
                                             batch_reward, D_WEIGHT)
                nll = generator.generator_step(sess, samples, rewards)
                # results
                print_rewards(rewards)
                print('neg-loglike: {}'.format(nll))
                results['neg-loglike'] = nll
            rollout.update_params()

            # generate for discriminator
            print('-> Training Discriminator')
            for i in range(D):
                print('D_Epoch {}'.format(i))
                d_loss, accuracy, ypred = train_discriminator()
                results['D_loss_{}'.format(i)] = d_loss
                results['Accuracy_{}'.format(i)] = accuracy
                results['Mean_ypred_{}'.format(i)] = ypred
            print('results')
            results_rows.append(results)
            if nbatch % params["EPOCH_SAVES"] == 0:
                save_results(sess, PREFIX, PREFIX + '_model_' + str(nbatch),
                             results_rows)

    # write results
        save_results(sess, PREFIX, PREFIX + '_model_' + str(nbatch),
                     results_rows)

    print('\n:*** FINISHED ***')
    return
Example No. 2
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = 68
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    # load generator with parameters
    generator = get_trainable_model(vocab_size)
    target_params = initialize_parameters(vocab_size)

    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              SEQ_LENGTH, START_TOKEN, target_params)

    # CNNs
    with tf.variable_scope('discriminator'):
        cnn = TextCNN(sequence_length=SEQ_LENGTH,
                      num_classes=2,
                      vocab_size=vocab_size,
                      embedding_size=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [
        param for param in tf.trainable_variables()
        if 'discriminator' in param.name
    ]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss,
                                                         cnn_params,
                                                         aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars,
                                                 global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # generate_samples(sess, target_lstm, 64, 10000, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open(logpath, 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            file_name = 'target_generate/pretrain_epoch' + str(epoch) + '.pkl'
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             file_name)
            likelihood_data_loader.create_batches(file_name)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            log.write(buffer)

            if epoch % 100 != 0:
                os.remove(file_name)

    file_name = 'target_generate/pretrain_finished.pkl'
    generate_samples(sess, generator, BATCH_SIZE, generated_num, file_name)
    likelihood_data_loader.create_batches(file_name)
    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    log.write(buffer)

    file_name = 'target_generate/supervise.pkl'
    generate_samples(sess, generator, BATCH_SIZE, generated_num, file_name)
    likelihood_data_loader.create_batches(file_name)
    significance_test(sess, target_lstm, likelihood_data_loader,
                      'significance/supervise.txt')

    os.remove(file_name)

    print 'Start training discriminator...'
    for i in range(dis_alter_epoch):
        print 'dis_alter_epoch : ' + str(i)
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_file, negative_file)
        dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train),
                                                 dis_batch_size,
                                                 dis_num_epochs)

        for batch in dis_batches:
            try:
                x_batch, y_batch = zip(*batch)
                feed = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, step = sess.run([dis_train_op, dis_global_step], feed)
            except ValueError:
                pass
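                # an empty or malformed batch raises ValueError in the unpack
                # above and is simply skipped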

    rollout = ROLLOUT(generator, 0.8)

    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    for total_batch in range(TOTAL_BATCH):
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
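            # 16 Monte Carlo rollouts per sample; the completed rollouts are
            # scored by the CNN discriminator to produce the rewards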
            rewards = rollout.get_reward(sess, samples, 16, cnn)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss],
                                 feed_dict=feed)

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:

            file_name = 'target_generate/reinforce_batch' + str(
                total_batch) + '.pkl'

            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             file_name)
            likelihood_data_loader.create_batches(file_name)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

            if total_batch % 50 != 0:
                os.remove(file_name)

            if test_loss < best_score:
                best_score = test_loss
                print 'best score: ', test_loss
                significance_test(sess, target_lstm, likelihood_data_loader,
                                  'significance/seqgan.txt')

        rollout.update_params()

        # generate for discriminator
        print 'Start training discriminator'
        for _ in range(5):
            # for _ in range(2):

            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)

            dis_x_train, dis_y_train = dis_data_loader.load_train_data(
                positive_file, negative_file)
            dis_batches = dis_data_loader.batch_iter(
                zip(dis_x_train, dis_y_train), dis_batch_size, 3)

            for batch in dis_batches:
                try:
                    x_batch, y_batch = zip(*batch)
                    feed = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, step = sess.run([dis_train_op, dis_global_step], feed)
                except ValueError:
                    pass

    log.close()
Example No. 3
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    # assert START_TOKEN == 0

    vocab_size = NUM_EMB
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM,
                          HIDDEN_DIM, MAX_LENGTH, START_TOKEN)
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE,
                              EMB_DIM, HIDDEN_DIM, MAX_LENGTH, 0)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(
            sequence_length=MAX_LENGTH,
            num_classes=2,
            vocab_size=vocab_size,
            embedding_size=dis_embedding_dim,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [param for param in tf.trainable_variables()
                  if 'discriminator' in param.name]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(
        cnn.loss, cnn_params, aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(
        dis_grads_and_vars, global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    def train_discriminator():
        if D_WEIGHT == 0:
            return 0, 0

        negative_samples = generate_samples(
            sess, generator, BATCH_SIZE, POSITIVE_NUM)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_samples, negative_samples)
        dis_batches = dis_data_loader.batch_iter(
            zip(dis_x_train, dis_y_train), dis_batch_size, dis_num_epochs
        )

        for batch in dis_batches:
            x_batch, y_batch = zip(*batch)
            feed = {
                cnn.input_x: x_batch,
                cnn.input_y: y_batch,
                cnn.dropout_keep_prob: dis_dropout_keep_prob
            }
            _, step, loss, accuracy = sess.run(
                [dis_train_op, dis_global_step, cnn.loss, cnn.accuracy], feed)
        print('\tD loss  :   {}'.format(loss))
        print('\tAccuracy: {}'.format(accuracy))
        return loss, accuracy

    # Pre-training is checkpointed and only executes if we don't find a checkpoint
    saver = tf.train.Saver()
    ckpt_dir = 'checkpoints/{}_pretrain'.format(PREFIX)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    ckpt_file = os.path.join(ckpt_dir, 'pretrain_ckpt')
    if os.path.isfile(ckpt_file + '.meta') and params["LOAD_PRETRAIN"]:
        saver.restore(sess, ckpt_file)
        print('Pretrain loaded from previous checkpoint {}'.format(ckpt_file))
    else:
        sess.run(tf.global_variables_initializer())
        pretrain(sess, generator, target_lstm, train_discriminator)
        path = saver.save(sess, ckpt_file)
        print('Pretrain finished and saved at {}'.format(path))

    # create reward function
    batch_reward = make_reward(train_samples)
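    # make_reward presumably builds the objective-based reward from the training
    # samples; D_WEIGHT later balances it against the discriminator signal inside
    # rollout.get_reward.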

    rollout = ROLLOUT(generator, 0.8)

    print('#########################################################################')
    print('Start Reinforcement Training Generator...')
    results_rows = []
    for nbatch in range(TOTAL_BATCH):
        results = OrderedDict({'exp_name': PREFIX})
        if nbatch % 1 == 0 or nbatch == TOTAL_BATCH - 1:
            print('* Making samples')
            if nbatch % 10 == 0:
                gen_samples = generate_samples(
                    sess, generator, BATCH_SIZE, BIG_SAMPLE_NUM)
            else:
                gen_samples = generate_samples(
                    sess, generator, BATCH_SIZE, SAMPLE_NUM)
            likelihood_data_loader.create_batches(gen_samples)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('batch_num: {}'.format(nbatch))
            print('test_loss: {}'.format(test_loss))
            results['Batch'] = nbatch
            results['test_loss'] = test_loss

            if test_loss < best_score:
                best_score = test_loss
                print('best score: %f' % test_loss)

            # results
            mm.compute_results(gen_samples, train_samples, ord_dict, results)

        print('#########################################################################')
        print('-> Training generator with RL.')
        print('G Epoch {}'.format(nbatch))

        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(
                sess, samples, 16, cnn, batch_reward, D_WEIGHT)
            print('Rewards be like...')
            print(rewards)
            nll = generator.generator_step(sess, samples, rewards)

            print('neg-loglike: {}'.format(nll))
            results['neg-loglike'] = nll
        rollout.update_params()

        # generate for discriminator
        print('-> Training Discriminator')
        for i in range(D):
            print('D_Epoch {}'.format(i))
            d_loss, accuracy = train_discriminator()
            results['D_loss_{}'.format(i)] = d_loss
            results['Accuracy_{}'.format(i)] = accuracy
        print('results')
        results_rows.append(results)
        if nbatch % params["EPOCH_SAVES"] == 0:
            save_results(sess, PREFIX, PREFIX + '_model', results_rows)

    # write results
    save_results(sess, PREFIX, PREFIX + '_model', results_rows)

    print('\n:*** FINISHED ***')
    return
Example No. 4
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = 5000
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = get_trainable_model(vocab_size)

    # oracle model : target lstm
    # target_params = cPickle.load(open('save/target_params.pkl'))
    # target_lstm = TARGET_LSTM(vocab_size, 64, 32, 32, SEQ_LENGTH, 0, target_params)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(
            sequence_length=SEQ_LENGTH,
            num_classes=2,
            vocab_size=vocab_size,
            embedding_size=dis_embedding_dim,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [param for param in tf.trainable_variables() if 'discriminator' in param.name]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss, cnn_params, aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars, global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # czq
    # generate real data
    # generate_samples(sess, target_lstm, 64, 10000, positive_file)

    # store real data for next step
    positive_data = np.load(positive_file).tolist()
    gen_data_loader.create_batches(positive_data)

    log = open('log/seq_mle_experiment-log.txt', 'w')
    #  pre-train generator
    print '#########################################################################'
    print 'Start pre-training generator...'
    log.write('pre-training...\n')

    for epoch in xrange(PRE_EPOCH_NUM):
        # print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            # likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            # print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            # buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            buffer = 'pre-trained generator:' + str(epoch) + ' ' + str(loss)
            print(buffer)
            log.write(buffer + '\n')

    # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    # likelihood_data_loader.create_batches(eval_file)
    # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    # buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    buffer = 'After pre-training:' + ' ' + str(loss)
    print(buffer)
    log.write(buffer + '\n')

    # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    # likelihood_data_loader.create_batches(eval_file)
    # significance_test(sess, target_lstm, likelihood_data_loader, 'significance/supervise.txt')

    # test purpose only
    generate_samples(sess, generator, BATCH_SIZE, 100, final_trans_file_mle)

    # exit(0)

    print 'Start pre-training discriminator...'
    for _ in range(dis_alter_epoch):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(positive_file, negative_file)
        dis_batches = dis_data_loader.batch_iter(
            zip(dis_x_train, dis_y_train), dis_batch_size, dis_num_epochs
        )

        for batch in dis_batches:
            try:
                x_batch, y_batch = zip(*batch)
                feed = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, step = sess.run([dis_train_op, dis_global_step], feed)
            except Exception as e:
                # print str(e)
                raise

        loss = sess.run(cnn.loss, feed)
        buffer = 'pre-train discriminator' + ' ' + str(loss)
        print buffer
        log.write(buffer + '\n')

    rollout = ROLLOUT(generator, 0.8)
    print('Before GAN')
    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    # for tensorboard
    # writer = tf.summary.FileWriter('./tb_logs', graph=tf.get_default_graph())

    for total_batch in range(TOTAL_BATCH):
        print 'progress', total_batch, '/', TOTAL_BATCH
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, cnn)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss, pre_loss = sess.run([generator.g_updates, generator.g_loss, generator.pretrain_loss],
                                           feed_dict=feed)
            buffer = 'G-step:' + str(TRAIN_ITER) + ':' + str(g_loss) + '|' + str(pre_loss)
            log.write(buffer + '\n')
            print(buffer)
            # if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            #     generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            #     likelihood_data_loader.create_batches(eval_file)
            #     # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            #     # buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            #     # print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            #     log.write(buffer)

            # if test_loss < best_score:
            #     best_score = test_loss
            #     print 'best score: ', test_loss
            #     significance_test(sess, target_lstm, likelihood_data_loader, 'significance/seqgan.txt')

        rollout.update_params()

        # generate for discriminator
        print('Start training discriminator')
        log.write('training discriminator...\n')

        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

            dis_x_train, dis_y_train = dis_data_loader.load_train_data(positive_file, negative_file)
            dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train), dis_batch_size, 3)

            for batch in dis_batches:
                try:
                    x_batch, y_batch = zip(*batch)
                    feed = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, step = sess.run([dis_train_op, dis_global_step], feed)
                except ValueError:
                    pass

            loss = sess.run(cnn.loss, feed)
            buffer = 'discriminator' + ' ' + str(loss)
            print buffer
            log.write(buffer + '\n')

    log.close()

    # save the model
    # saver = tf.train.Saver({"gen": generator})
    # saver.save(sess, 'my-model')

    # generate samples
    generate_samples(sess, generator, BATCH_SIZE, 100, final_trans_file_seqgan)
Example No. 5
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    stringGenerator = TextGenerator('../corpus/index2word.pickle',
                                    '../corpus/word2index.pickle',
                                    '../corpus/all.code')

    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = len(stringGenerator.index2Word)
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = get_trainable_model(vocab_size)
    target_params = cPickle.load(open('save/target_params.pkl'))
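    # The saved oracle parameters are resized to this corpus' vocabulary: the
    # first entry (the vocab_size x 32 embedding) and the last two (the
    # 32 x vocab_size output weights and the output bias) are re-initialised
    # randomly below.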
    target_params[00] = np.random.rand(vocab_size, 32).astype(np.float32)
    target_params[-2] = np.random.rand(32, vocab_size).astype(np.float32)
    target_params[-1] = np.random.rand(vocab_size).astype(np.float32)
    target_lstm = TARGET_LSTM(vocab_size, 64, 32, 32, 20, 0, target_params)

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(sequence_length=20,
                      num_classes=2,
                      vocab_size=vocab_size,
                      embedding_size=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [
        param for param in tf.trainable_variables()
        if 'discriminator' in param.name
    ]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss,
                                                         cnn_params,
                                                         aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars,
                                                 global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.initialize_all_variables())

    #generate_samples(sess, target_lstm, 64, 10000, positive_file)
    stringGenerator.saveSamplesToFile(20, 10000, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('log/experiment-log.txt', 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    significance_test(sess, target_lstm, likelihood_data_loader,
                      'significance/supervise.txt')

    print 'Start training discriminator...'
    for _ in range(dis_alter_epoch):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_file, negative_file)
        dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train),
                                                 dis_batch_size,
                                                 dis_num_epochs)

        for batch in dis_batches:
            try:
                x_batch, y_batch = zip(*batch)
                feed = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, step = sess.run([dis_train_op, dis_global_step], feed)
            except ValueError:
                pass

    rollout = ROLLOUT(generator, 0.8)

    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    for total_batch in range(TOTAL_BATCH):
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, cnn)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss],
                                 feed_dict=feed)

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

            if test_loss < best_score:
                best_score = test_loss
                print 'best score: ', test_loss
                significance_test(sess, target_lstm, likelihood_data_loader,
                                  'significance/seqgan.txt')

        rollout.update_params()

        # generate for discriminator
        print 'Start training discriminator'
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)

            dis_x_train, dis_y_train = dis_data_loader.load_train_data(
                positive_file, negative_file)
            dis_batches = dis_data_loader.batch_iter(
                zip(dis_x_train, dis_y_train), dis_batch_size, 3)

            for batch in dis_batches:
                try:
                    x_batch, y_batch = zip(*batch)
                    feed = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, step = sess.run([dis_train_op, dis_global_step], feed)
                except ValueError:
                    pass

    log.close()
Example No. 6
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = 5000
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = get_trainable_model(vocab_size)
    target_params = cPickle.load(open('save/target_params.pkl'))
    target_lstm = TARGET_LSTM(vocab_size, 64, 32, 32, 20, 0, target_params)
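    # target_lstm is the fixed oracle: it generates the positive data below and
    # scores generator samples via target_loss.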

    with tf.variable_scope('discriminator'):
        cnn = TextCNN(
            sequence_length=20,
            num_classes=2,
            vocab_size=vocab_size,
            embedding_size=dis_embedding_dim,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [param for param in tf.trainable_variables() if 'discriminator' in param.name]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss, cnn_params, aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars, global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    generate_samples(sess, target_lstm, 64, 10000, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('log/experiment-log.txt', 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    significance_test(sess, target_lstm, likelihood_data_loader, 'significance/supervise.txt')

    print 'Start training discriminator...'
    for _ in range(dis_alter_epoch):
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(positive_file, negative_file)
        dis_batches = dis_data_loader.batch_iter(
            zip(dis_x_train, dis_y_train), dis_batch_size, dis_num_epochs
        )

        for batch in dis_batches:
            try:
                x_batch, y_batch = zip(*batch)
                feed = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, step = sess.run([dis_train_op, dis_global_step], feed)
            except ValueError:
                pass

    rollout = ROLLOUT(generator, 0.8)

    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    for total_batch in range(TOTAL_BATCH):
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, cnn)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss], feed_dict=feed)

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

            if test_loss < best_score:
                best_score = test_loss
                print 'best score: ', test_loss
                significance_test(sess, target_lstm, likelihood_data_loader, 'significance/seqgan.txt')

        rollout.update_params()

        # generate for discriminator
        print 'Start training discriminator'
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

            dis_x_train, dis_y_train = dis_data_loader.load_train_data(positive_file, negative_file)
            dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train), dis_batch_size, 3)

            for batch in dis_batches:
                try:
                    x_batch, y_batch = zip(*batch)
                    feed = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, step = sess.run([dis_train_op, dis_global_step], feed)
                except ValueError:
                    pass

    log.close()
Example No. 7
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    # assert START_TOKEN == 0

    vocab_size = NUM_EMB
    dis_data_loader = Dis_dataloader()

    best_score = 1000
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM,
                          HIDDEN_DIM, MAX_LENGTH, START_TOKEN)
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE,
                              EMB_DIM, HIDDEN_DIM, MAX_LENGTH, 0)

    with tf.variable_scope('discriminator'):
        # The discriminator here is DeepChem's graph-convolution classifier
        # instead of the TextCNN used in the other examples.
        # batch_size = dis_batch_size
        cnn = GraphConvTensorGraph(n_tasks=1, batch_size=dis_batch_size,
                                   mode='classification')
        if not cnn.built:
            cnn.build()

    # Commented-out remnants of the original DeepChem training loop:
    #  with cnn._get_tf("Graph").as_default():
    #    manager = cnn._get_tf("Graph").as_default()
    #    manager.__enter__()
    #    train_op = cnn._get_tf('train_op')

    # Define Discriminator Training procedure
    cnn_params = [param for param in tf.trainable_variables()
                  if 'discriminator' in param.name]

    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(
        cnn.loss.out_tensor, cnn_params, aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(
        dis_grads_and_vars, global_step=dis_global_step)



    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    def train_discriminator():
        if D_WEIGHT == 0:
            return 0, 0

        # Fetch the class probabilities, the train op and the loss
        # (added from deepchem)
        output_tensors = [cnn.outputs[0].out_tensor]
        fetches = output_tensors + [dis_train_op, cnn.loss.out_tensor]

        negative_samples = generate_samples(
            sess, generator, BATCH_SIZE, POSITIVE_NUM)

        #  train discriminator
        # pos_smiles = [mm.decode(sample, ord_dict) for sample in positive_samples]  # already defined in intro
        neg_smiles = [mm.decode(sample, ord_dict) for sample in negative_samples]
        # df_pos = pd.DataFrame({'real/fake': [1] * len(pos_smiles), 'smiles': pos_smiles})  # already defined in intro
        df_neg = pd.DataFrame({'real/fake': [0] * len(neg_smiles), 'smiles': neg_smiles})
        # df_total = df_pos.append(df_neg)
        df_total = pd.concat([df_pos, df_neg], ignore_index=True)

        def shuffle_by_batches(df, batch_size=dis_batch_size):
            # Shuffle whole batches rather than single rows, so each batch stays
            # all-real or all-fake (see Tip 4: https://github.com/soumith/ganhacks).
            l = len(df)
            l_batches = range(l)
            n_total = int(np.ceil(l / float(batch_size)))
            l_batches = [l_batches[x * batch_size:(x + 1) * batch_size]
                         for x in range(n_total)]
            permut = np.random.permutation(len(l_batches))
            l_new = [l_batches[n] for n in permut]
            l_new = [item for sublist in l_new for item in sublist]  # flatten l_new
            df = df.reindex(np.array(l_new))
            df = df.reset_index(drop=True)
            return df

        df_total = shuffle_by_batches(df_total, batch_size=dis_batch_size)

        loader = deepchem.data.PandasLoader(tasks=['real/fake'],
                                            smiles_field="smiles",
                                            featurizer=deepchem.feat.ConvMolFeaturizer())
        train_dataset = loader.featurize(df_total, shard_size=8192)
        feed_dict_generator = cnn.default_generator(train_dataset, epochs=1)

        def create_feed_dict():
            for d in feed_dict_generator:
                feed_dict = {k.out_tensor: v for k, v in six.iteritems(d)}
                feed_dict[cnn._training_placeholder] = 1.0
                yield feed_dict

        def select_label(feed_dict):
            # Select the output label layer in the feed dictionary
            newkeys = []
            for k in feed_dict.keys():
                if 'Label' in k.name:
                    newkeys.append(k)
            return newkeys[0]

        avg_loss, avg_acc, n_batches = 0.0, 0.0, 0

        for feed_dict in create_feed_dict():
            fetched_values = sess.run(fetches, feed_dict=feed_dict)
            loss = fetched_values[-1]
            predicted_results = [np.argmax(x) for x in fetched_values[0]]
            ground_truth = [np.argmax(x) for x in feed_dict[select_label(feed_dict)]]
            results = [predicted_results[n] == ground_truth[n]
                       for n in range(len(predicted_results))]
            accuracy = float(sum(results)) / len(results)
            n_batches += 1
            ratio = 1 / float(n_batches)
            avg_loss = (1 - ratio) * avg_loss + ratio * loss
            avg_acc = (1 - ratio) * avg_acc + ratio * accuracy
        print('\tD loss  :   {}'.format(avg_loss))
        print('\tAccuracy: {}'.format(avg_acc))
        return avg_loss, avg_acc

    # Pre-training is checkpointed and only executes if we don't find a checkpoint
    saver = tf.train.Saver()
    ckpt_dir = 'checkpoints_new/{}_pretrain'.format(PREFIX)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    ckpt_file = os.path.join(ckpt_dir, 'pretrain_ckpt')
    if os.path.isfile(ckpt_file + '.meta') and params["LOAD_PRETRAIN"]:
        saver.restore(sess, ckpt_file)
        print('Pretrain loaded from previous checkpoint {}'.format(ckpt_file))
    else:
        if params["LOAD_PRETRAIN"]:
            print('\t* No pre-training data found as {:s}.'.format(ckpt_file))
        else:
            print('\t* LOAD_PRETRAIN was set to false.')
        cnn._initialize_weights(sess, saver)  # added from deepchem
        sess.run(tf.global_variables_initializer())
        pretrain(sess, generator, target_lstm, train_discriminator)
        path = saver.save(sess, ckpt_file)
        print('Pretrain finished and saved at {}'.format(path))

    # create reward function
    batch_reward = make_reward(train_samples)

    rollout = ROLLOUT(generator, 0.8)

    print('#########################################################################')
    print('Start Reinforcement Training Generator...')
    results_rows = []
    for nbatch in tqdm(range(TOTAL_BATCH)):
        results = OrderedDict({'exp_name': PREFIX})
        if nbatch % 1 == 0 or nbatch == TOTAL_BATCH - 1:
            print('* Making samples')
            if nbatch % 10 == 0:
                gen_samples = generate_samples(
                    sess, generator, BATCH_SIZE, BIG_SAMPLE_NUM)
            else:
                gen_samples = generate_samples(
                    sess, generator, BATCH_SIZE, SAMPLE_NUM)
            likelihood_data_loader.create_batches(gen_samples)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('batch_num: {}'.format(nbatch))
            print('test_loss: {}'.format(test_loss))
            results['Batch'] = nbatch
            results['test_loss'] = test_loss

            if test_loss < best_score:
                best_score = test_loss
                print('best score: %f' % test_loss)

            # results
            mm.compute_results(gen_samples, train_samples, ord_dict, results)

        print('#########################################################################')
        print('-> Training generator with RL.')
        print('G Epoch {}'.format(nbatch))

        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(
                sess, samples, 16, cnn,  ord_dict, batch_reward, D_WEIGHT)
            nll = generator.generator_step(sess, samples, rewards)
            # results
            print_rewards(rewards)
            print('neg-loglike: {}'.format(nll))
            results['neg-loglike'] = nll
        rollout.update_params()

        # generate for discriminator
        print('-> Training Discriminator')
        for i in range(D):
            print('D_Epoch {}'.format(i))
            d_loss, accuracy = train_discriminator()
            results['D_loss_{}'.format(i)] = d_loss
            results['Accuracy_{}'.format(i)] = accuracy
        print('results')
        results_rows.append(results)
        if nbatch % params["EPOCH_SAVES"] == 0:
            save_results(sess, PREFIX, PREFIX + '_model', results_rows)

    # write results
    save_results(sess, PREFIX, PREFIX + '_model', results_rows)

    print('\n:*** FINISHED ***')
    return
Example No. 8
def main():
    random.seed(SEED)
    np.random.seed(SEED)

    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Likelihood_data_loader(BATCH_SIZE)
    vocab_size = 5000
    dis_data_loader = Dis_dataloader()

    best_score = 1000

    # initialize an LSTM object and use it to initialize the PoemGen object
    generator = get_trainable_model(vocab_size)

    # cPickle is an object serialization library

    # the loaded pickle object will be an array of numbers
    # later, these params will be used to initialize the target LSTM
    target_params = cPickle.load(open('save/target_params.pkl'))
    # print target_params
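    # NOTE: the 1000-second pause below looks like a leftover debugging aid;
    # it blocks the run for roughly 17 minutes.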

    time.sleep(1000)

    # This is the target (oracle) model, an RNN-based LSTM generator
    target_lstm = TARGET_LSTM(vocab_size, 64, 32, 32, 20, 0, target_params)

    # This is the discriminator, which uses a CNN
    with tf.variable_scope('discriminator'):
        cnn = TextCNN(sequence_length=20,
                      num_classes=2,
                      vocab_size=vocab_size,
                      embedding_size=dis_embedding_dim,
                      filter_sizes=dis_filter_sizes,
                      num_filters=dis_num_filters,
                      l2_reg_lambda=dis_l2_reg_lambda)

    cnn_params = [
        param for param in tf.trainable_variables()
        if 'discriminator' in param.name
    ]
    # Define Discriminator Training procedure
    dis_global_step = tf.Variable(0, name="global_step", trainable=False)
    dis_optimizer = tf.train.AdamOptimizer(1e-4)
    dis_grads_and_vars = dis_optimizer.compute_gradients(cnn.loss,
                                                         cnn_params,
                                                         aggregation_method=2)
    dis_train_op = dis_optimizer.apply_gradients(dis_grads_and_vars,
                                                 global_step=dis_global_step)

    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction = 0.5
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.initialize_all_variables())

    generate_samples(sess, target_lstm, 64, 10000, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('log/experiment-log.txt', 'w')
    #  pre-train generator
    # Initialize the generator with MLE estimators
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        print 'pre-train epoch:', epoch
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = str(epoch) + ' ' + str(test_loss) + '\n'
            log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
    buffer = 'After pre-training:' + ' ' + str(test_loss) + '\n'
    log.write(buffer)

    generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
    likelihood_data_loader.create_batches(eval_file)
    significance_test(sess, target_lstm, likelihood_data_loader,
                      'significance/supervise.txt')

    print 'Start training discriminator...'
    for _ in range(dis_alter_epoch):
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)

        #  train discriminator
        dis_x_train, dis_y_train = dis_data_loader.load_train_data(
            positive_file, negative_file)
        dis_batches = dis_data_loader.batch_iter(zip(dis_x_train, dis_y_train),
                                                 dis_batch_size,
                                                 dis_num_epochs)

        for batch in dis_batches:
            try:
                x_batch, y_batch = zip(*batch)
                feed = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dis_dropout_keep_prob
                }
                _, step = sess.run([dis_train_op, dis_global_step], feed)
            except ValueError:
                pass

    rollout = ROLLOUT(generator, 0.8)

    print '#########################################################################'
    print 'Start Reinforcement Training Generator...'
    log.write('Reinforcement Training...\n')

    for total_batch in range(TOTAL_BATCH):
        for it in range(TRAIN_ITER):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, 16, cnn)
            feed = {generator.x: samples, generator.rewards: rewards}
            _, g_loss = sess.run([generator.g_updates, generator.g_loss],
                                 feed_dict=feed)

        if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
            # The trainable model 'generator' is an RNN model from PoemGen

            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = str(total_batch) + ' ' + str(test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

            if test_loss < best_score:
                best_score = test_loss
                print 'best score: ', test_loss
                significance_test(sess, target_lstm, likelihood_data_loader,
                                  'significance/seqgan.txt')

        # update the rollout policy (it tracks the current generator parameters)
        rollout.update_params()

        # generate for discriminator
        print 'Start training discriminator'
        for _ in range(5):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)

            dis_x_train, dis_y_train = dis_data_loader.load_train_data(
                positive_file, negative_file)
            dis_batches = dis_data_loader.batch_iter(
                zip(dis_x_train, dis_y_train), dis_batch_size, 3)

            for batch in dis_batches:
                try:
                    x_batch, y_batch = zip(*batch)
                    feed = {
                        cnn.input_x: x_batch,
                        cnn.input_y: y_batch,
                        cnn.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _, step = sess.run([dis_train_op, dis_global_step], feed)
                except ValueError:
                    pass

    log.close()