Exemple #1
0
def main():
    """Run CoT-style cooperative training on the synthetic oracle task.

    Pipeline: sample "real" sequences from a fixed oracle LSTM, optionally
    pre-train the generator with MLE, then alternate generator policy
    updates (rewards supplied by the mediator) with mediator maximum-
    likelihood updates on a shuffled mix of real and generated batches.
    Depends on module-level constants (SEED, BATCH_SIZE, SEQ_LENGTH,
    EMB_DIM, HIDDEN_DIM, START_TOKEN, M_DROPOUT_RATE, PRE_EPOCH_NUM,
    TOTAL_BATCH, generated_num, positive_file, negative_file, eval_file)
    and helpers (generate_samples, target_loss, mle_epoch, jsd_calculate)
    defined elsewhere in this file.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    # Three loaders over token-id files: training data, held-out eval data,
    # and a scratch loader for periodic NLL measurement.
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    val_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE,
                                             SEQ_LENGTH)  # For testing
    vocab_size = 5000

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    # Frozen oracle parameters; pickle of a repo-local file (trusted data).
    target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, 32, 32, SEQ_LENGTH,
                              START_TOKEN, target_params)  # The oracle model

    # The mediator consumes double-width batches (two BATCH_SIZE halves
    # concatenated), hence BATCH_SIZE * 2 and the doubled model sizes.
    mediator = Generator(vocab_size,
                         BATCH_SIZE * 2,
                         EMB_DIM * 2,
                         HIDDEN_DIM * 2,
                         SEQ_LENGTH,
                         START_TOKEN,
                         name="mediator",
                         dropout_rate=M_DROPOUT_RATE)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num,
                     positive_file)
    gen_data_loader.create_batches(positive_file)
    generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, eval_file)
    val_data_loader.create_batches(eval_file)

    log = open('save/experiment-log.txt', 'w')
    log_nll = open('save/experiment-log-nll.txt', 'w')
    log_jsd = open('save/experiment-log-jsd.txt', 'w')
    #  pre-train generator (default 0 epochs)(not recommended)
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = mle_epoch(sess, generator, gen_data_loader)
        # NLL of generator samples scored by the oracle (lower = closer to
        # the oracle distribution).
        if epoch % 1 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            likelihood_data_loader.create_batches(negative_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch, 'nll_oracle ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll_oracle:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)
        # NLL of the held-out oracle data scored by the generator.
        if epoch % 1 == 0:
            test_loss = target_loss(sess, generator, val_data_loader)
            print('pre-train epoch ', epoch, 'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll_test:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)

    print(
        '#########################################################################'
    )
    print('Start Cooperative Training...')
    for iter_idx in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(2):
            samples = generator.generate(sess)
            # The mediator expects a 2*BATCH_SIZE batch, so the samples are
            # duplicated and only the first half of the rewards is used.
            rewards = mediator.get_reward(
                sess, np.concatenate([samples, samples], axis=0))
            feed = {
                generator.x: samples,
                generator.rewards: rewards[0:BATCH_SIZE]
            }
            _ = sess.run(generator.g_updates, feed_dict=feed)
        # Test
        if iter_idx % 100 == 0 or iter_idx == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            likelihood_data_loader.create_batches(negative_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'batch:\t' + str(iter_idx) + '\tnll_oracle:\t' + str(
                test_loss) + '\n'
            print('batch: ', iter_idx, 'nll_oracle: ', test_loss)
            log_nll.write(buffer)
        if iter_idx % 100 == 0:
            test_loss = target_loss(sess, generator, val_data_loader)
            print('batch:\t', iter_idx, 'nll_test ', test_loss)
            buffer = 'batch:\t' + str(iter_idx) + '\tnll_test:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)
        # Train the mediator
        for _ in range(1):
            bnll_ = []
            collected_x = []
            ratio = 2
            # Gather one real batch and one generated batch, then shuffle
            # them together so each mediator batch mixes both sources.
            for it in range(ratio):
                if it % 2 == 0:
                    x_batch = gen_data_loader.next_batch()
                else:
                    x_batch = generator.generate(sess)
                collected_x.append(x_batch)
            collected_x = np.reshape(collected_x, [-1, SEQ_LENGTH])
            np.random.shuffle(collected_x)
            collected_x = np.reshape(collected_x,
                                     [-1, BATCH_SIZE * 2, SEQ_LENGTH])
            for it in range(1):
                feed = {
                    mediator.x: collected_x[it],
                }
                # Record the loss before the update ("balanced" NLL on the
                # mixed batch), then take one likelihood step.
                bnll = sess.run(mediator.likelihood_loss, feed)
                bnll_.append(bnll)
                # sess.run(mediator.dropout_on)
                _ = sess.run(mediator.likelihood_updates, feed)
                # sess.run(mediator.dropout_off)
        # Periodic mediator diagnostics: balanced NLL vs. NLL of a purely
        # generated double-batch.
        if (iter_idx * 4) % gen_data_loader.num_batch == 0:
            bnll = np.mean(bnll_)
            gnll = sess.run(
                mediator.likelihood_loss,
                feed_dict={
                    mediator.x:
                    np.reshape(
                        [generator.generate(sess),
                         generator.generate(sess)],
                        [BATCH_SIZE * 2, SEQ_LENGTH])
                })
            print("mediator cooptrain iter#%d, balanced_nll %f, g_nll %f" %
                  (iter_idx, bnll, gnll))
            log.write("%d\t%f\n" % (iter_idx, bnll))
        # Once per pass over the training data, estimate the JSD between
        # generator and oracle distributions.
        if iter_idx % gen_data_loader.num_batch == 0:
            jsd = jsd_calculate(sess, generator, target_lstm)
            print('cooptrain epoch#', iter_idx // gen_data_loader.num_batch,
                  'jsd ', jsd)
            log_jsd.write("%d\t%f\n" %
                          (iter_idx // gen_data_loader.num_batch, jsd))

    log.close()
    log_nll.close()
    log_jsd.close()
Exemple #2
0
def main(unused_argv):
    """Run SeqGAN training against a synthetic oracle (Python 3 variant).

    Stages: (1) load frozen oracle parameters and build generator,
    rollout, and discriminator graphs; (2) MLE pre-train the generator
    on oracle data, then pre-train the discriminator on real vs.
    generated sequences; (3) adversarial loop — policy-gradient generator
    updates using Monte-Carlo rollout rewards from the discriminator,
    interleaved with further discriminator training.

    Args:
        unused_argv: absl/tf.app flag residue; ignored.

    Relies on module-level configs and helpers (training_config,
    generator_config, discriminator_config, config_hardware, rollout,
    StrToBytes, generate_samples, target_loss) defined elsewhere in
    this file.
    """
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0

    #Build dataloader for generaotr, testing and discriminator
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    #Build generator and its rollout
    # (the rollout net completes partial sequences for reward estimation)
    generator = Generator(config=config_gen)
    generator.build()
    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    # NOTE(review): the pickle file is opened in text mode and wrapped by
    # StrToBytes for py2-pickle compatibility; the handle is never closed.
    target_params = cPickle.load(StrToBytes(open('save/target_params.pkl')),
                                 encoding='bytes')
    target_lstm = TARGET_LSTM(config=config_gen,
                              params=target_params)  # The oracle model

    #Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    #Build optimizer op for pretraining (Adam + global-norm gradient clip)
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]  #Using name 'teller' here to prevent name collision of target LSTM
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_upate = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))

    #Initialize all variables
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())

    #Initalize data loader of generator
    # NOTE(review): oracle sampling is commented out, so positive_file must
    # already exist on disk from a previous run.
    # generate_samples(sess, target_lstm, config_train.batch_size, config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    #Start pretraining (supervised MLE on the oracle data)
    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator...')
    log.write('pre-training...\n')
    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _, g_loss = sess.run([gen_pre_upate, generator.pretrained_loss], feed_dict={generator.input_seqs_pre:batch,\
                                                                                    generator.input_seqs_mask:np.ones_like(batch)})
        if epoch % config_train.test_per_epoch == 0:
            # Eval sampling is disabled, so eval_file must already exist;
            # test_loss is the NLL of eval_file under the oracle.
            # generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    #Pre-train the discriminator on real (positive) vs. generated (negative)
    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x:
                    x_batch,
                    discriminator.input_y:
                    y_batch,
                    discriminator.dropout_keep_prob:
                    config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    #Build optimizer op for adversarial training (same clipping scheme)
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    #Initialize global variables of optimizer for adversarial training
    # (only the freshly created optimizer slots; trainables are already set)
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    #Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            # Sample complete sequences from the generator.
            samples = sess.run(generator.sample_word_list_reshape)

            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            #calcuate the reward given in the specific stpe t by roll out
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)
                rollout_list_stack = np.vstack(
                    rollout_list
                )  #shape: #batch_size * #rollout_step, #sequence length
                # Discriminator "realness" scores for the rollouts and for
                # the finished sequences (last-token reward).
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                # De-interleave the stacked scores back into one reward
                # vector (length sequence_length) per batch element.
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            # Average over rollouts, then take one policy-gradient step.
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv], feed_dict={generator.input_seqs_adv:samples,\
                                                                                        generator.rewards:rewards})
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            # Periodic eval: NLL of fresh generator samples under the oracle.
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)

        # Keep training the discriminator on fresh negative samples.
        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)

            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x:
                        x_batch,
                        discriminator.input_y:
                        y_batch,
                        discriminator.dropout_keep_prob:
                        config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
Exemple #3
0
def main(unused_argv):
    """Run SeqGAN training against a synthetic oracle (Python 3 port).

    Stages: (1) build generator, rollout, oracle, and discriminator
    graphs; (2) sample the "real" corpus from the oracle, MLE pre-train
    the generator, then pre-train the discriminator; (3) adversarial
    loop — policy-gradient generator updates with Monte-Carlo rollout
    rewards from the discriminator, interleaved with more discriminator
    training.

    Args:
        unused_argv: absl/tf.app flag residue; ignored.

    Fixes vs. the original: Python 2 ``print`` statements and ``xrange``
    replaced with their Python 3 equivalents (the rest of this file is
    Python 3), and the pickle file is opened in binary mode, which
    ``pickle``/``cPickle`` requires under Python 3.

    Relies on module-level configs and helpers (training_config,
    generator_config, discriminator_config, config_hardware, rollout,
    generate_samples, target_loss) defined elsewhere in this file.
    """
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0

    # Build data loaders for the generator, for evaluation, and for the
    # discriminator.
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    # Build the generator graph and its rollout twin; the rollout net
    # completes partially generated sequences when estimating per-step
    # rewards.
    generator = Generator(config=config_gen)
    generator.build()
    rollout_gen = rollout(config=config_gen)

    # Build the oracle model from frozen parameters (binary mode is
    # required for pickle under Python 3).
    target_params = cPickle.load(open('save/target_params.pkl', 'rb'))
    target_lstm = TARGET_LSTM(config=config_gen,
                              params=target_params)  # The oracle model

    # Build the discriminator (sequence classifier providing rewards).
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    # Optimizer op for generator pre-training: Adam over the 'teller'
    # variables (shared by generator and rollout) with global-norm clipping.
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]  # Using name 'teller' here to prevent name collision of target LSTM
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_upate = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))

    # Initialize all variables.
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())

    # Let the oracle synthesize the "real" training corpus, then batch it.
    generate_samples(sess, target_lstm, config_train.batch_size,
                     config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    # ---- Generator pre-training (supervised MLE on oracle data) ----
    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator...')
    log.write('pre-training...\n')
    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            # One supervised step on a batch of oracle sequences.
            batch = gen_data_loader.next_batch()
            _, g_loss = sess.run([gen_pre_upate, generator.pretrained_loss],
                                 feed_dict={
                                     generator.input_seqs_pre: batch,
                                     generator.input_seqs_mask:
                                     np.ones_like(batch)
                                 })
        if epoch % config_train.test_per_epoch == 0:
            # Evaluate: NLL of fresh generator samples under the oracle.
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    # ---- Discriminator pre-training on real vs. generated sequences ----
    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        # Fresh negatives from the generator, mixed with oracle positives.
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob:
                    config_dis.dis_dropout_keep_prob
                }
                # Cross-entropy step; this net later provides the rewards.
                _ = sess.run(discriminator.train_op, feed)

    # Optimizer op for adversarial (policy-gradient) training, clipped the
    # same way as pre-training.
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize only the freshly created optimizer slots; trainable
    # variables are already initialized.
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    # ---- Adversarial training ----
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            # Sample complete sequences from the generator.
            samples = sess.run(generator.sample_word_list_reshape)

            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            # Estimate the reward at each step t by Monte-Carlo rollout.
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)
                rollout_list_stack = np.vstack(
                    rollout_list
                )  # shape: #batch_size * #rollout_step, #sequence length
                # Discriminator "realness" scores for the rollouts and for
                # the finished sequences (reward of the last token).
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                # De-interleave the stacked scores back into one reward
                # vector (length sequence_length) per batch element.
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            # Average rewards over rollouts, then take one policy step.
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv],
                                   feed_dict={
                                       generator.input_seqs_adv: samples,
                                       generator.rewards: rewards
                                   })
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            # Periodic eval: NLL of fresh generator samples under the oracle.
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)

        # Keep training the discriminator on fresh negative samples.
        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)

            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob:
                        config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
Exemple #4
0
def main():
    """Train a text GAN on real sentence data (SST/IMDB token-id files).

    Pipeline: load a vocabulary, MLE pre-train the generator on a large
    corpus, pre-train the discriminator on real vs. generated sentences,
    then run adversarial training where each generator step mixes
    discriminator-reward updates on its own samples with reward-weighted
    updates on real batches, followed by discriminator training and a few
    MLE "anchor" epochs.  Depends on module-level constants (BATCH_SIZE,
    EMB_DIM, HIDDEN_DIM, MAX_SEQ_LENGTH, TOTAL_BATCH, dis_* settings,
    *_file/_file_id paths) and helpers (load_emb_data, generate_samples,
    produce_samples, build_from_ids, generate_infer, pre_train_epoch)
    defined elsewhere in this file.
    """

    # load embedding info
    vocab_dict, vocab_size, vocab_list = load_emb_data(emb_dict_file)

    # prepare data: pre-training uses IMDB + SST, adversarial "real"
    # batches use SST only.
    pre_train_data_loader = Gen_Data_loader(BATCH_SIZE, vocab_dict)
    pre_train_data_loader.create_batches(
        [imdb_file_id, sst_pos_file_id, sst_neg_file_id])

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, vocab_dict)
    gen_data_loader.create_batches([sst_pos_file_id, sst_neg_file_id])

    dis_data_loader = Dis_Data_loader(BATCH_SIZE, vocab_dict, MAX_SEQ_LENGTH)

    # build model
    # num_emb, vocab_dict, batch_size, emb_dim, num_units, sequence_length
    generator = Generator(vocab_size, vocab_dict, BATCH_SIZE, EMB_DIM,
                          HIDDEN_DIM, MAX_SEQ_LENGTH)
    discriminator = Discriminator(sequence_length=MAX_SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    log = open('save/experiment-log.txt', 'w')

    # ---- Generator pre-training (MLE) ----
    buffer = 'Start pre-training generator...'
    print(buffer)
    log.write(buffer + '\n')
    for epoch in range(150):  #120
        train_loss = pre_train_epoch(sess, generator, pre_train_data_loader)
        if epoch % 5 == 0:
            # NOTE(review): this generate_samples signature (count=1,
            # if_log/epoch kwargs) differs from the other examples in this
            # file — presumably a project-local variant; verify in utils.
            generate_samples(sess,
                             generator,
                             1,
                             eval_file,
                             vocab_list,
                             if_log=True,
                             epoch=epoch)
            print('    pre-train epoch ', epoch, 'train_loss ', train_loss)
            buffer = '    epoch:\t' + str(epoch) + '\tnll:\t' + str(
                train_loss) + '\n'
            log.write(buffer)

    # ---- Discriminator pre-training on real vs. generated data ----
    buffer = 'Start pre-training discriminator...'
    print(buffer)
    log.write(buffer)
    for _ in range(10):  # 10
        generate_samples(sess, generator, 70, negative_file, vocab_list)
        dis_data_loader.load_train_data([sst_pos_file_id, sst_neg_file_id],
                                        [negative_file])
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                }
                d_loss, d_acc, _ = sess.run([
                    discriminator.loss, discriminator.accuracy,
                    discriminator.train_op
                ], feed)
        # d_loss/d_acc hold the values from the last mini-batch only.
        buffer = "discriminator loss %f acc %f" % (d_loss, d_acc)
        print(buffer)
        log.write(buffer + '\n')

    print("Start Adversarial Training...")
    log.write('adversarial training...')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator
        for it in range(2):
            # print("1")
            samples = generator.generate(sess)
            samples = produce_samples(samples)
            # print("2")
            # Rollout-style reward from the discriminator (16 rollouts).
            rewards = generator.get_reward(sess, samples, 16, discriminator)
            # print("3")
            # Log one decoded sample with its raw ids and rewards.
            a = str(samples[0])
            b = str(rewards[0])
            # rewards = change_rewards(rewards)
            # c = str(rewards[0])
            d = build_from_ids(samples[0], vocab_list)
            buffer = "%s\n%s\n%s\n\n" % (d, a, b)
            print(buffer)
            log.write(buffer)

            # print("4")
            rewards_loss = generator.update_with_rewards(
                sess, samples, rewards)
            # print("5")
            # good rewards
            # good_samples = gen_data_loader.next_batch()
            # rewards = np.array([[0.0001] * SEQ_LENGTH] * BATCH_SIZE)
            # a = str(good_samples[0])
            # b = str(rewards[0])
            # buffer = "%s\n%s\n\n" % (a, b)
            # print(buffer)
            # log.write(buffer)
            # rewards_loss = generator.update_with_rewards(sess, good_samples, rewards, START_TOKEN)

            # little1 good reward: also reinforce real SST batches, scored
            # by the same discriminator reward.
            little1_samples = gen_data_loader.next_batch()
            rewards = generator.get_reward(sess, little1_samples, 16,
                                           discriminator)
            a = str(little1_samples[0])
            b = str(rewards[0])
            buffer = "%s\n%s\n\n" % (a, b)
            # print(buffer)
            log.write(buffer)
            rewards_loss = generator.update_with_rewards(
                sess, little1_samples, rewards)

        # generate_infer(sess, generator, epoch, vocab_list)

        # Test: periodically dump samples, log the last reward loss, and
        # checkpoint the generator.
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess,
                             generator,
                             120,
                             eval_file,
                             vocab_list,
                             if_log=True)
            generate_infer(sess, generator, total_batch, vocab_list)
            # generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)
            # likelihood_data_loader.create_batches(eval_file)
            # test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'reward-train epoch %s train loss %s' % (
                str(total_batch), str(rewards_loss))
            print(buffer)
            log.write(buffer + '\n')
            generator.save_model(sess)

        # Train the discriminator
        # `begin` ensures the loss/acc line is logged at most once per
        # outer batch.
        begin = True
        for _ in range(1):
            generate_samples(sess, generator, 70, negative_file, vocab_list)
            dis_data_loader.load_train_data([sst_pos_file_id, sst_neg_file_id],
                                            [negative_file])
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                    }
                    d_loss, d_acc, _ = sess.run([
                        discriminator.loss, discriminator.accuracy,
                        discriminator.train_op
                    ], feed)
                    if (total_batch % 5 == 0
                            or total_batch == TOTAL_BATCH - 1) and begin:
                        buffer = "discriminator loss %f acc %f\n" % (d_loss,
                                                                     d_acc)
                        print(buffer)
                        log.write(buffer)
                        begin = False

        # pretrain: a few MLE epochs to keep the generator anchored to
        # real data after the adversarial updates.
        for _ in range(10):
            train_loss = pre_train_epoch(sess, generator,
                                         pre_train_data_loader)
Exemple #5
0
def main():
    """End-to-end training of an image-to-text SeqGAN.

    Stages (all side effects: files, a TF session, log writes):
      1. Build or load the cached word<->id dictionaries.
      2. Pre-train the discriminator on real captions vs. generator samples.
      3. MLE pre-train the generator on MobileNet image features.
      4. REINFORCE training with rollout rewards from the discriminator.
      5. Decode one final batch of samples to text and append them to the log.
    """
    random.seed(SEED)
    np.random.seed(SEED)

    # Dictionary construction is cached on disk; rebuild only when missing.
    if os.path.exists(DICO_PKL):
        with open(DICO_PKL, 'rb') as f:
            word_to_id, id_to_word = pickle.load(f)
    else:
        word_to_id, id_to_word = create_dico(DICO)
        with open(DICO_PKL, 'wb') as f:
            pickle.dump([word_to_id, id_to_word], f)

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, word_to_id)
    dis_data_loader = Dis_Data_loader(BATCH_SIZE, word_to_id)
    vocab_size = len(word_to_id)
    # The generator's start token must be the corpus start-of-sentence id.
    assert START_TOKEN == word_to_id['sos']

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    discriminator = BLEUCNN(SEQ_LENGTH, 2, EMB_DIM, generator)
    # MobileNet supplies the image-feature vectors that condition generation.
    mobilenet = MobileNet(BATCH_SIZE)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # NOTE(review): pretrained MobileNet weights are loaded *before*
    # global_variables_initializer() runs; if load_pretrained_weights assigns
    # into ordinary tf.Variables, the initializer below would reset them --
    # confirm against MobileNet.load_pretrained_weights.
    mobilenet.load_pretrained_weights(sess)
    sess.run(tf.global_variables_initializer())

    log = open('experiment-log.txt', 'w', encoding='utf-8')
    #  pre-train generator and discriminator
    log.write('pre-training...\n')
    print('Start pre-training discriminator...')
    datas = create_data(DICO, word_to_id)
    gen_data_loader.create_batches(CORPUS, IMAGE)
    # Negative examples: captions sampled from the (untrained) generator,
    # conditioned on MobileNet features of each image batch.
    samples = []
    for it in range(gen_data_loader.num_batch):
        inp_batch, image_batch = gen_data_loader.next_batch()
        feed_dict = {mobilenet.X: image_batch, mobilenet.is_training: False}
        hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
        samples.extend(generator.generate(sess, hidden_batch).tolist())
    # 3000 randomly drawn real captions vs. the generated negatives.
    dis_data_loader.create_batches(random.sample(datas, 3000), samples)
    for _ in range(PRE_EPOCH_NUM):
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch):
            x_batch, labels = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.labels: labels,
                discriminator.dropout_keep_prob: 0.75
            }
            _ = sess.run(discriminator.train_op, feed)

    print('Start pre-training generator...')
    # Supervised MLE pre-training on (caption, image-feature) pairs.
    for epoch in range(PRE_EPOCH_NUM):
        supervised_g_losses = []
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            inp_batch, image_batch = gen_data_loader.next_batch()
            feed_dict = {
                mobilenet.X: image_batch,
                mobilenet.is_training: False
            }
            hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
            _, g_loss = generator.pretrain_step(sess, inp_batch, hidden_batch)
            supervised_g_losses.append(g_loss)
        loss = np.mean(supervised_g_losses)
        if epoch % 5 == 0:
            print('pre-train epoch ', epoch, 'train_loss ', loss)
            buffer = 'epoch:\t' + str(epoch) + '\ttrain_loss:\t' + str(
                loss) + '\n'
            log.write(buffer)

    # Rollout policy tracking the generator with update rate 0.8.
    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start REINFORCE Training...')
    log.write('REINFORCE training...\n')
    for total_batch in range(RL_EPOCH_NUM):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            # NOTE(review): shuffle is randomly 0 or 1 per batch; the exact
            # effect depends on next_batch's implementation -- confirm.
            ra = random.randint(0, 1)
            inp_batch, image_batch = gen_data_loader.next_batch(shuffle=ra)
            feed_dict = {
                mobilenet.X: image_batch,
                mobilenet.is_training: False
            }
            hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
            samples = generator.generate(sess, hidden_batch)
            # 16 Monte-Carlo rollouts per sample, scored by the discriminator.
            rewards = rollout.get_reward(sess, samples, hidden_batch, 16,
                                         discriminator)
            feed = {
                generator.x: inp_batch,
                generator.rewards: rewards,
                generator.hiddens: hidden_batch
            }
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        # Every 5 epochs (and on the last), report the mean final-step reward
        # over the whole training set and checkpoint the generator.
        if total_batch % 5 == 0 or total_batch == RL_EPOCH_NUM - 1:
            mean_rewards = []
            gen_data_loader.reset_pointer()
            for it in range(gen_data_loader.num_batch):
                inp_batch, image_batch = gen_data_loader.next_batch()
                feed_dict = {
                    mobilenet.X: image_batch,
                    mobilenet.is_training: False
                }
                hidden_batch = sess.run(mobilenet.y_output,
                                        feed_dict=feed_dict)
                samples = generator.generate(sess, hidden_batch)
                rewards = rollout.get_reward(sess, samples, hidden_batch, 16,
                                             discriminator)
                mean_rewards.append(np.mean(rewards[:, -1]))
            reward = np.mean(mean_rewards)
            buffer = 'epoch:\t' + str(total_batch) + '\treward:\t' + str(
                reward) + '\n'
            print('total_batch: ', total_batch, 'reward: ', reward)
            log.write(buffer)
            generator.save_weight(sess)

        # Update roll-out parameters
        rollout.update_params()
        discriminator.update_embedding()

        # Train the discriminator
        # Rebuild the negative set from the current generator, then one pass.
        samples = []
        for it in range(gen_data_loader.num_batch):
            inp_batch, image_batch = gen_data_loader.next_batch()
            feed_dict = {
                mobilenet.X: image_batch,
                mobilenet.is_training: False
            }
            hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
            samples.extend(generator.generate(sess, hidden_batch).tolist())
        dis_data_loader.create_batches(random.sample(datas, 3000), samples)
        dis_data_loader.reset_pointer()
        for it in range(dis_data_loader.num_batch):
            x_batch, labels = dis_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.labels: labels,
                discriminator.dropout_keep_prob: 0.75
            }
            _ = sess.run(discriminator.train_op, feed)

    # final test
    # Decode one batch of generated captions back to words and log them.
    gen_data_loader.reset_pointer()
    _, image_batch = gen_data_loader.next_batch()
    feed_dict = {mobilenet.X: image_batch, mobilenet.is_training: False}
    hidden_batch = sess.run(mobilenet.y_output, feed_dict=feed_dict)
    samples = generator.generate(sess, hidden_batch)
    y = samples.tolist()
    sams = []
    for k, sam in enumerate(y):
        sa = [id_to_word[i] for i in sam]
        sa = ''.join(sa)
        sams.append(sa)
    for sam in sams:
        log.write(sam + '\n')
    log.close()
Exemple #6
0
TEST BEGIN @3.29
TEST 1     @4.18
'''

# Adversarial-training fragment: relies on `log`, `gen_data_loader`,
# `generator`, `discriminator`, `sess`, `get_reward` and `TOTAL_BATCH`
# defined earlier in the (not fully visible) script.
print(
    '#########################################################################'
)
print('Start Adversarial Training...')
log.write('adversarial training...\n')
# NOTE(review): `sampel_log` (sic) is opened but never written in this
# fragment -- presumably used further below; confirm before removing.
sampel_log = open('save/sample-log.txt', 'w')
gen_data_loader.reset_pointer()
for total_batch in range(TOTAL_BATCH):
    # Train the generator for one step
    samples = None  # overwritten on the first inner iteration
    # Five policy-gradient steps per outer batch; rewards come from 16
    # rollouts scored by the discriminator.
    for it in range(5):
        batch, ques_len = gen_data_loader.next_batch()
        samples = generator.generate(sess, batch, ques_len)
        rewards = get_reward(sess, samples, 16, generator, discriminator)
        # print("rewards sample: ", rewards[0])
        feed = {
            generator.x: samples,
            generator.rewards: rewards,
            generator.target_sequence_length: ques_len,
            generator.max_sequence_length_per_batch: max(ques_len)
        }
        _, g_loss = sess.run([generator.g_updates, generator.g_loss],
                             feed_dict=feed)

    # Log only the last inner step's generator loss for this outer batch.
    buffer = 'epoch:\t' + str(total_batch) + '\tg_loss:\t' + str(g_loss) + '\n'
    print('total_batch: ', total_batch, 'g_loss: ', g_loss)
    log.write(buffer)
Exemple #7
0
def main():
    """SeqGAN training driver with checkpointing.

    Pre-trains the generator (MLE with LSTM/reconstruction/KL losses) and
    the discriminator, then alternates one policy-gradient generator step
    (rollout-based rewards) with discriminator refreshes. Progress is
    appended to save/experiment-log.txt.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    # assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(
        vocab_size, condition_size, FEATURE_NUM, BATCH_SIZE, EMB_DIM,
        COND_DIM, HIDDEN_DIM, Z_DIM, SEQ_LENGTH, START_TOKEN,
        vocab_file, condition_file, word_vec=word_vec)

    discriminator = Discriminator(
        sequence_length=20, num_classes=2, vocab_size=vocab_size,
        embedding_size=dis_embedding_dim, filter_sizes=dis_filter_sizes,
        num_filters=dis_num_filters, l2_reg_lambda=dis_l2_reg_lambda)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    sess.run(tf.global_variables_initializer())

    # Resume from the latest checkpoint when one exists.
    saver = tf.train.Saver()
    checkpoint = get_ckpt(ckpt_dir)
    if checkpoint is not None:
        print("Load checkpoints from: ", checkpoint)
        saver.restore(sess, checkpoint)

    # Real examples for MLE pre-training.
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')

    #  pre-train generator
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        g_loss, lstm_loss, recon_loss, kl_loss = pre_train_epoch(
            sess, generator, gen_data_loader)
        if epoch % 10 != 0:
            continue
        line = (
            'pre-train epoch %d, g_loss: %f, lstm_loss: %f, recon_loss: %f, kl_loss: %f'
            % (epoch, g_loss, lstm_loss, recon_loss, kl_loss))
        log.write(line + '\n')
        print(line)
        generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                         generated_num, eval_file)
        if epoch % 20 == 0:
            saver.save(sess,
                       os.path.join(ckpt_dir, 'checkpoint_' + str(epoch)))

    print('Start pre-training discriminator...')
    # Train 3 epochs on freshly generated negatives, repeated 50 times.
    for _ in range(50):
        generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                         generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for _ in range(dis_data_loader.num_batch):
                inputs, labels = dis_data_loader.next_batch()
                sess.run(discriminator.train_op, {
                    discriminator.input_x: inputs,
                    discriminator.input_y: labels,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                })

    rollout = ROLLOUT(generator, 0.8)

    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # One policy-gradient step: sample, score with 16 rollouts, update.
        samples = generator.generate(sess, gen_data_loader.next_batch())
        rewards = rollout.get_reward(sess, samples, 16, discriminator)
        sess.run(generator.g_updates,
                 feed_dict={generator.x: samples, generator.rewards: rewards})

        # Keep the rollout network in sync with the generator.
        rollout.update_params()

        # Refresh the discriminator on newly generated negatives.
        for _ in range(5):
            generate_samples(sess, generator, gen_data_loader, BATCH_SIZE,
                             generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for _ in range(dis_data_loader.num_batch):
                    inputs, labels = dis_data_loader.next_batch()
                    sess.run(discriminator.train_op, {
                        discriminator.input_x: inputs,
                        discriminator.input_y: labels,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    })

    log.close()
Exemple #8
0
def main():
    """Cooperative Training (CoT) of a generator and a mediator on real text.

    Builds the vocabulary from `true_file`/`val_file`, writes id-encoded
    corpora, optionally MLE pre-trains the generator, then alternates
    generator policy-gradient steps (rewards from the mediator) with
    mediator balanced-likelihood updates. Metrics go to files under save/
    and checkpoints to saved_model/CoT.
    """
    print('program start')
    from utils.text_process import text_precess, text_to_code  # TODO: move?
    from utils.text_process import get_tokenlized, get_word_list, get_dict

    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    # Sequence length and vocabulary size are derived from the real corpora.
    SEQ_LENGTH, vocab_size = text_precess(true_file, val_file)

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    # Kept for the (currently disabled) adversarial mediator phase.
    gan_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    val_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE,
                                             SEQ_LENGTH)  # For testing

    # Tokenize both corpora, build the word<->index dicts and write the
    # id-encoded versions consumed by the data loaders.
    tokens = get_tokenlized(true_file)
    val_tokens = get_tokenlized(val_file)
    word_set = get_word_list(tokens + val_tokens)
    [word_index_dict, index_word_dict] = get_dict(word_set)
    with open(oracle_file, 'w') as outfile:
        outfile.write(text_to_code(tokens, word_index_dict, SEQ_LENGTH))
    with open(val_oracle_file, 'w') as outfile:
        outfile.write(text_to_code(val_tokens, word_index_dict, SEQ_LENGTH))

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)

    # The mediator estimates the mixture density of real and generated data;
    # its reward replaces the SeqGAN discriminator.
    mediator = Mediator(vocab_size,
                        BATCH_SIZE,
                        EMB_DIM * 2,
                        HIDDEN_DIM * 2,
                        SEQ_LENGTH,
                        START_TOKEN,
                        name="mediator",
                        dropout_rate=M_DROPOUT_RATE,
                        learning_rate=3e-3,
                        with_professor_forcing=False)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Real data replaces the synthetic oracle of the original SeqGAN setup.
    gen_data_loader.create_batches(oracle_file)
    gan_data_loader.create_batches(oracle_file)
    val_data_loader.create_batches(val_oracle_file)

    log = open('save/experiment-log.txt', 'w')
    log_nll = open('save/experiment-log-nll.txt', 'w')

    #  pre-train generator (default 0 epochs)(not recommended)
    print('Start pre-training...')
    log.write('pre-training...\n')
    saver = tf.train.Saver(tf.global_variables())
    if RESTORE:
        saver.restore(sess, "saved_model/CoT")
    for epoch in range(PRE_EPOCH_NUM):
        loss = mle_epoch(sess, generator, gen_data_loader)
        # FIX: the original also computed an "oracle NLL" here via
        # target_lstm, but target_lstm is never constructed in this
        # function (its creation is commented out above), so the first
        # pre-train epoch raised a NameError. The dead oracle branch is
        # removed; only the validation NLL remains.
        if epoch % 1 == 0:
            test_loss = target_loss(sess, generator, val_data_loader)
            print('pre-train epoch ', epoch, 'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll_test:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)

    print(
        '#########################################################################'
    )
    toc = time.time()
    print('Start Cooperative Training...')
    for iter_idx in range(TOTAL_BATCH):
        print('iteration: ' + str(iter_idx) + '\ntime: ' +
              str(time.time() - toc))
        toc = time.time()
        # Generator step: sample sequences, score them with the mediator,
        # take one policy-gradient update.
        for it in range(1):
            samples = generator.generate(sess)
            rewards = mediator.get_reward(sess, samples)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
        # Evaluate once per epoch (every num_batch iterations) on the
        # validation set and checkpoint.
        if iter_idx % gen_data_loader.num_batch == 0:
            test_loss = target_loss(sess, generator, val_data_loader)
            print('epoch:\t', iter_idx // gen_data_loader.num_batch,
                  'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(
                iter_idx // gen_data_loader.num_batch) + '\tnll_test:\t' + str(
                    test_loss) + '\n'
            log_nll.write(buffer)
            saver.save(sess, "saved_model/CoT")
        # Mediator step: one balanced-likelihood update on a real batch
        # (x0) and a generated batch (x1); dropout is switched on only for
        # the update itself so the reported loss is deterministic.
        for _ in range(1):
            bnll_ = []
            for it in range(1):
                feed = {
                    mediator.x0: gen_data_loader.next_batch(),
                    mediator.x1: generator.generate(sess)
                }
                bnll = sess.run(mediator.likelihood_loss, feed)
                bnll_.append(bnll)
                sess.run(mediator.dropout_on)
                _ = sess.run(mediator.likelihood_updates, feed)
                sess.run(mediator.dropout_off)
            if iter_idx % 10 == 0:
                bnll = np.mean(bnll_)
                print("mediator cooptrain iter#%d, balanced_nll %f" %
                      (iter_idx, bnll))
                log.write("%d\t%f\n" % (iter_idx, bnll))
    log.close()
    log_nll.close()
def main(unused_argv):
    """SeqGAN training driven by config objects.

    MLE pre-trains the generator against a pickled oracle LSTM, pre-trains
    the discriminator on real vs. generated files, then alternates
    rollout-reward policy-gradient generator updates with discriminator
    updates. NLL metrics are appended to save/experiment-log.txt.
    """
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()

    np.random.seed(config_train.seed)

    assert config_train.start_token == 0
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    generator = Generator(config=config_gen)
    generator.build()

    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    # iso-8859-1 lets a Python-2 era pickle load under Python 3.
    target_params = pickle.load(open('save/target_params.pkl','rb'),encoding='iso-8859-1')
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params) # The oracle model


    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()


    # Build optimizer op for pretraining: gradient-clipped Adam over the
    # 'teller' (generator) variables only.
    pretrained_optimizer = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    var_pretrained = [v for v in tf.trainable_variables() if 'teller' in v.name]
    gradients, variables = zip(
        *pretrained_optimizer.compute_gradients(generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(zip(gradients, variables))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Positive examples are sampled from the oracle distribution.
    generate_samples(sess,target_lstm,config_train.batch_size,config_train.generated_num,config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    log = open('save/experiment-log.txt','w')
    print('Start pre-training generator....')

    log.write('pre-training...\n')

    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _,g_loss = sess.run([gen_pre_update,generator.pretrained_loss],feed_dict={generator.input_seqs_pre:batch,
                                                                                      generator.input_seqs_mask:np.ones_like(batch)})

        if epoch % config_train.test_per_epoch == 0:
            # Evaluation: generate a batch of sequences with the generator,
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.eval_file)
            # build a data loader over the generated sequences,
            likelihood_data_loader.create_batches(config_train.eval_file)
            # and use the oracle to compute the cross-entropy loss (NLL).
            test_loss = target_loss(sess,target_lstm,likelihood_data_loader)
            # Print and append to the log.
            print('pre-train ',epoch, ' test_loss ',test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)


    print('Start pre-training discriminator...')
    # NOTE(review): dis_update_time_pre controls both the outer refresh loop
    # and the inner epoch loop; a separate epoch count may have been intended.
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)
        for _ in range(config_train.dis_update_time_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch,y_batch = dis_data_loader.next_batch()
                feed_dict = {
                    discriminator.input_x : x_batch,
                    discriminator.input_y : y_batch,
                    discriminator.dropout_keep_prob : config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op,feed_dict)



    # Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize global variables of optimizer for adversarial training
    # (only the new Adam slot variables are still uninitialized here).
    uninitialized_var = [e for e in tf.global_variables() if e not in tf.trainable_variables()]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    # Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            # (sic: 'sample_word_list_reshpae' is the attribute name as
            # defined on the Generator class; the typo originates there.)
            samples = sess.run(generator.sample_word_list_reshpae)

            feed = {'pred_seq_rollout:0':samples}
            reward_rollout = []
            # Monte-Carlo rollouts: score partial sequences plus the full
            # sequence with the discriminator, then average.
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,feed_dict=feed)
                # np.vstack stacks the per-step arrays vertically (row-wise).
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:rollout_list_stack,discriminator.dropout_keep_prob:1.0
                })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:samples,discriminator.dropout_keep_prob:1.0
                })
                reward_allseq = np.concatenate((reward_rollout_seq,reward_last_tok),axis=0)[:,1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(r,config_gen.gen_batch_size * config_gen.sequence_length,config_gen.gen_batch_size)])

                reward_rollout.append(np.array(reward_tmp))
                # NOTE(review): the reward average and the generator update
                # sit INSIDE the rollout loop, so the generator is updated
                # rollout_num times per sample batch on partially-averaged
                # rewards; canonical SeqGAN updates once after the loop --
                # confirm whether this indentation is intentional.
                rewards = np.sum(reward_rollout,axis = 0) / config_train.rollout_num
                _,gen_loss = sess.run([train_adv_update,generator.gen_loss_adv],feed_dict={generator.input_seqs_adv:samples,
                                                                                           generator.rewards:rewards})


        # Periodic oracle-NLL evaluation of the generator.
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print ('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)


        # Discriminator refresh on freshly generated negatives.
        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)

            for _ in range(config_train.dis_update_time_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch,y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x:x_batch,
                        discriminator.input_y:y_batch,
                        discriminator.dropout_keep_prob:config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op,feed)

    log.close()
def main():
    """Cooperative Training (CoT) on real text with a Generator-as-mediator.

    Encodes the real corpora to token ids, optionally MLE pre-trains the
    generator, then alternates generator policy-gradient steps (rewards
    from the mediator) with mediator likelihood updates on shuffled
    half-real / half-generated batches. Writes final samples to
    generator_file / test_file.
    """
    print('program start')
    from utils.text_process import text_precess, text_to_code  # TODO: move?
    from utils.text_process import get_tokenlized, get_word_list, get_dict

    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    # Sequence length and vocabulary size come from the real corpora.
    SEQ_LENGTH, vocab_size = text_precess(true_file, val_file)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    val_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)

    # Create training file and dicts
    tokens = get_tokenlized(true_file)
    val_tokens = get_tokenlized(val_file)
    word_set = get_word_list(tokens + val_tokens)
    [word_index_dict, index_word_dict] = get_dict(word_set)
    with open(oracle_file, 'w') as outfile:
        outfile.write(text_to_code(tokens, word_index_dict, SEQ_LENGTH))
    with open(val_oracle_file, 'w') as outfile:
        outfile.write(text_to_code(val_tokens, word_index_dict, SEQ_LENGTH))

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    #target_params = pickle.load(open('save/target_params_py3.pkl', 'rb'))
    #target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, 32, 32, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model
    # replace target lstm with true data

    # The mediator is a second, wider Generator with dropout; it consumes
    # batches of size 2*BATCH_SIZE (a real half and a generated half).
    mediator = Generator(vocab_size,
                         BATCH_SIZE * 2,
                         EMB_DIM * 2,
                         HIDDEN_DIM * 2,
                         SEQ_LENGTH,
                         START_TOKEN,
                         name="mediator",
                         dropout_rate=M_DROPOUT_RATE)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    gen_data_loader.create_batches(oracle_file)
    val_data_loader.create_batches(val_oracle_file)

    log = open('save/experiment-log.txt', 'w')
    log_nll = open('save/experiment-log-nll.txt', 'w')

    #  pre-train generator (default 0 epochs)(not recommended)
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = mle_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             generator_file)
            #get_real_test_file(index_word_dict, generator_file, test_file) # only needed in debugging
            test_loss = target_loss(sess, generator, val_data_loader)
            print('pre-train epoch ', epoch, 'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll_test:\t' + str(
                test_loss) + '\n'
            log_nll.write(buffer)

    print(
        '#########################################################################'
    )
    toc = time.time()
    print('Start Cooperative Training...')
    for iter_idx in range(TOTAL_BATCH):
        print('iteration: ' + str(iter_idx) + '\ntime: ' +
              str(time.time() - toc))
        toc = time.time()
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            # The mediator's batch is 2*BATCH_SIZE, so the samples are
            # duplicated to fill it and the rewards sliced back down.
            rewards = mediator.get_reward(
                sess, np.concatenate([samples, samples], axis=0))
            feed = {
                generator.x: samples,
                generator.rewards: rewards[0:BATCH_SIZE]
            }
            loss, _ = sess.run([generator.g_loss, generator.g_updates],
                               feed_dict=feed)
        # Test, removed oracle test
        if iter_idx % gen_data_loader.num_batch == 0:  # epochs instead of batches
            test_loss = target_loss(sess, generator, val_data_loader)
            print('epoch:\t', iter_idx // gen_data_loader.num_batch,
                  'nll_test ', test_loss)
            buffer = 'epoch:\t' + str(
                iter_idx // gen_data_loader.num_batch) + '\tnll_test:\t' + str(
                    test_loss) + '\n'
            log_nll.write(buffer)
        # On the last iteration, dump generated samples and decode them
        # back to plain text.
        if iter_idx == TOTAL_BATCH - 1:
            print('generating samples')
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             generator_file)
            get_real_test_file(index_word_dict, generator_file, test_file)
        # Train the mediator
        for _ in range(1):
            print('training mediator...')
            bnll_ = []
            collected_x = []
            # Interleave one real batch and one generated batch, then
            # shuffle them together into a single 2*BATCH_SIZE mediator
            # batch.
            ratio = 2
            for it in range(ratio):
                if it % 2 == 0:
                    x_batch = gen_data_loader.next_batch()
                else:
                    x_batch = generator.generate(sess)
                collected_x.append(x_batch)
            collected_x = np.reshape(collected_x, [-1, SEQ_LENGTH])
            np.random.shuffle(collected_x)
            collected_x = np.reshape(collected_x,
                                     [-1, BATCH_SIZE * 2, SEQ_LENGTH])
            for it in range(1):
                feed = {
                    mediator.x: collected_x[it],
                }
                print('running bnll sess')
                # Loss is measured before the update (dropout off), then
                # dropout is enabled only for the update itself.
                bnll = sess.run(mediator.likelihood_loss, feed)
                bnll_.append(bnll)
                print('running mediator and updating')
                sess.run(mediator.dropout_on)
                _ = sess.run(mediator.likelihood_updates, feed)
                sess.run(mediator.dropout_off)
            if iter_idx % 50 == 0:
                bnll = np.mean(bnll_)
                print("mediator cooptrain iter#%d, balanced_nll %f" %
                      (iter_idx, bnll))
                log.write("%d\t%f\n" % (iter_idx, bnll))

    log.close()
    log_nll.close()
Exemple #11
0
def main():
    """Cooperative training (CoT) of a generator against a mediator on real text.

    Pipeline:
      1. Tokenize ``train_set``, derive ``SEQ_LENGTH``/``vocab_size`` from the
         corpus, build the word<->index dictionaries, and write the encoded
         corpus to ``positive_file``.
      2. For ``TOTAL_BATCH`` iterations, alternate:
         - one policy-gradient update of the generator, with rewards supplied
           by the mediator;
         - one likelihood update of the mediator on a shuffled 50/50 mixture
           of one real batch and one generated batch.
    Generator NLL on the training batches is logged every 100 iterations and on
    the final iteration; mediator balanced-NLL is logged periodically.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    # The data loaders reserve index 0 as the sequence-start symbol.
    # (Was an `assert`, which is stripped under `python -O`.)
    if START_TOKEN != 0:
        raise ValueError('START_TOKEN must be 0')

    # Sequence length and vocabulary size are derived from the corpus itself.
    SEQ_LENGTH, vocab_size = text_precess(train_set)

    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)  # for testing
    # NOTE(review): the unused `val_data_loader` and the dead
    # `pickle.load('save/target_params_py3.pkl')` were removed — the oracle
    # TARGET_LSTM is not built in this variant, so the pickle was read for
    # nothing (and unpickling files is a needless security hazard).

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    # The mediator sees double-sized batches (real + generated halves).
    mediator = Generator(vocab_size, BATCH_SIZE * 2, EMB_DIM * 2,
                         HIDDEN_DIM * 2, SEQ_LENGTH, START_TOKEN,
                         name="mediator", dropout_rate=M_DROPOUT_RATE)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # Build the index-encoded training set.
    tokens = get_tokenlized(train_set)
    word_set = get_word_list(tokens)
    [word_index_dict, index_word_dict] = get_dict(word_set)
    with open(positive_file, 'w') as outfile:
        outfile.write(text_to_code(tokens, word_index_dict, SEQ_LENGTH))

    # Load batches from the index-encoded training set.
    gen_data_loader.create_batches(positive_file)

    run_tag = str(time())
    # Context managers guarantee the logs are flushed and closed even if
    # training raises part-way through (the original leaked both handles).
    with open('save/experiment-log' + run_tag + '.txt', 'w') as log, \
            open('save/experiment-log-nll' + run_tag + '.txt', 'w') as log_nll:
        print('#########################################################################')
        print('Start Cooperative Training...')
        print('Num batches: ' + str(gen_data_loader.num_batch))
        for iter_idx in range(TOTAL_BATCH):
            # --- Generator: one policy-gradient step ------------------------
            print('Training G')
            samples = generator.generate(sess)
            # The mediator runs on double-sized batches, so the sample batch
            # is duplicated and only the first half of the rewards is used.
            rewards = mediator.get_reward(
                sess, np.concatenate([samples, samples], axis=0))
            feed = {generator.x: samples,
                    generator.rewards: rewards[0:BATCH_SIZE]}
            loss, _ = sess.run([generator.g_loss, generator.g_updates],
                               feed_dict=feed)
            if iter_idx % gen_data_loader.num_batch == 0:
                print('cooptrain epoch#', iter_idx // gen_data_loader.num_batch)
                print('loss: ' + str(loss))

            # --- Periodic evaluation (no oracle available) ------------------
            if iter_idx % 100 == 0 or iter_idx == TOTAL_BATCH - 1:
                print('Generating fake samples')
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                likelihood_data_loader.create_batches(negative_file)
                print('Calculating NLL')
                # NOTE(review): NLL is measured on the training loader (as
                # Texygen does); a held-out validation set would be preferable.
                test_loss = target_loss(sess, generator, gen_data_loader)
                print('batch:\t', iter_idx, 'nll_test ', test_loss)
                log_nll.write('batch:\t' + str(iter_idx) + '\tnll_test:\t'
                              + str(test_loss) + '\n')

            # --- Mediator: one likelihood step on a mixed batch -------------
            print('Training M')
            bnll_ = []
            # One real batch and one generated batch, shuffled together at the
            # sequence level, then reshaped into double-sized mediator batches.
            collected_x = [gen_data_loader.next_batch(),
                           generator.generate(sess)]
            collected_x = np.reshape(collected_x, [-1, SEQ_LENGTH])
            np.random.shuffle(collected_x)
            collected_x = np.reshape(collected_x,
                                     [-1, BATCH_SIZE * 2, SEQ_LENGTH])
            print('Calculating BNLL')
            feed = {
                mediator.x: collected_x[0],
            }
            # Record the loss before applying the update, as in the original.
            bnll = sess.run(mediator.likelihood_loss, feed)
            bnll_.append(bnll)
            _ = sess.run(mediator.likelihood_updates, feed)

            # Periodic mediator diagnostics (roughly 4x per epoch).
            if (iter_idx * 4) % gen_data_loader.num_batch == 0:
                print('Calculating likelihood loss for M')
                bnll = np.mean(bnll_)
                gnll = sess.run(
                    mediator.likelihood_loss,
                    feed_dict={mediator.x: np.reshape(
                        [generator.generate(sess), generator.generate(sess)],
                        [BATCH_SIZE * 2, SEQ_LENGTH])})
                print("mediator cooptrain iter#%d, balanced_nll %f, g_nll %f"
                      % (iter_idx, bnll, gnll))
                log.write("%d\t%f\n" % (iter_idx, bnll))