Example no. 1
def create_model(sess, data, args, embed):
    #get maximum input sequence length
    data.restart("train", batch_size=args.batch_size, shuffle=True)
    batched_data = data.get_next_batch("train")
    length = []
    while batched_data is not None:
        length.append(len(batched_data["sent"][0]))
        batched_data = data.get_next_batch("train")
    sequence_length = np.max(length)

    latest_dir = '%s/checkpoint_latest' % args.model_dir
    best_dir = '%s/checkpoint_best' % args.model_dir

    summary = create_summary(args)
    with tf.variable_scope("generator"):
        #Build generator and its rollout
        generator = Generator(args, data, embed, summary, sequence_length,
                              latest_dir, best_dir)
        generator.build()
        rollout_gen = rollout(args, data, embed, sequence_length)

    with tf.variable_scope("discriminator"):
        #Build discriminator
        discriminator = Discriminator(args, data, embed, summary,
                                      sequence_length, latest_dir, best_dir)
        discriminator.build_discriminator()

    latest_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                  max_to_keep=args.checkpoint_max_to_keep,
                                  pad_step_number=True,
                                  keep_checkpoint_every_n_hours=1.0)
    best_saver = tf.train.Saver(write_version=tf.train.SaverDef.V2,
                                max_to_keep=1,
                                pad_step_number=True,
                                keep_checkpoint_every_n_hours=1.0)

    if tf.train.get_checkpoint_state(latest_dir) and args.restore == "last":
        print("Reading model parameters from %s" % latest_dir)
        latest_saver.restore(sess, tf.train.latest_checkpoint(latest_dir))
    elif tf.train.get_checkpoint_state(best_dir) and args.restore == "best":
        print('Reading model parameters from %s' % best_dir)
        best_saver.restore(sess, tf.train.latest_checkpoint(best_dir))
    else:
        print("Created model with fresh parameters.")
        sess.run(tf.variables_initializer(tf.global_variables()))
    generator.latest_saver, generator.best_saver = latest_saver, best_saver
    discriminator.latest_saver, discriminator.best_saver = latest_saver, best_saver
    '''
    generator.print_parameters()
    print("-----------------------------")
    discriminator.print_parameters()
    print("-----------------------------")
    '''
    return generator, discriminator, rollout_gen
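A minimal usage sketch (not part of the original snippet): the caller is assumed to have already built `args`, `data`, and `embed` in its training script.

# Hypothetical caller; `args`, `data`, and `embed` come from the surrounding training script
with tf.Session() as sess:
    generator, discriminator, rollout_gen = create_model(sess, data, args, embed)
    # both models now share the latest/best savers wired up inside create_model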
Example no. 2
    def __init__(self, img_length, num_colors, d_sizes, g_sizes):
        self.img_length = img_length
        self.num_colors = num_colors
        self.latent_dims = g_sizes["z"]
        self.X = tf.placeholder(tf.float32, shape=(None, img_length, img_length, num_colors), name="X")
        self.Z = tf.placeholder(tf.float32, shape=(None, self.latent_dims), name="Z")
        self.batch_sz = tf.placeholder(tf.int32, shape=(), name="batch_sz")
        dnrt = Discriminator(self.img_length, self.latent_dims, 64, num_colors)
        gnrt = Generator(self.img_length, self.latent_dims, 64)
        logits = dnrt.build_discriminator(self.X, d_sizes, DenseLayer, ConvLayer)
        self.sample_images = gnrt.build_generator(self.Z, g_sizes, DenseLayer, FractionallyStridedConvLayer)

        with tf.variable_scope("discriminator") as scope:
            scope.reuse_variables()
            sample_logits = dnrt.d_forward(self.sample_images, True)

        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()
            self.sample_images_test = gnrt.g_forward(self.Z, reuse = True, is_training = False)

        # Leftover debugging output: `d` is a symbolic Tensor here, so this prints
        # the Tensor object rather than a value, and the +7.5 has no effect on training
        d = tf.reduce_sum(self.sample_images - self.sample_images_test) + 7.5
        print("anuj", d)
        self.d_cost_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits))
        self.d_cost_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits= sample_logits,
                                                                   labels=tf.zeros_like(sample_logits))
        self.d_cost = tf.reduce_mean(self.d_cost_real) + tf.reduce_mean(self.d_cost_fake)

        self.g_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = sample_logits,
                                                                             labels=tf.ones_like(sample_logits)))
        real_predictions = tf.cast(logits > 0, tf.float32)
        fake_predictions = tf.cast(sample_logits < 0, tf.float32)
        num_predictions = 2 * BATCH_SIZE
        num_correct = tf.reduce_sum(real_predictions) + tf.reduce_sum(fake_predictions)
        self.d_accuracy = num_correct / num_predictions


        #optimizers
        self.d_params = [t for t in tf.trainable_variables() if t.name.startswith("d")]
        self.g_params = [t for t in tf.trainable_variables() if t.name.startswith("g")]

        self.d_train_op = tf.train.AdamOptimizer(LEARNING_RATE, beta1=BETA1).minimize(self.d_cost,
                                                                                      var_list=self.d_params)
        self.g_train_op = tf.train.AdamOptimizer(LEARNING_RATE, beta1=BETA1).minimize(self.g_cost,
                                                                                      var_list=self.g_params)
        self.init_op = tf.global_variables_initializer()
        self.sess = tf.InteractiveSession()
        self.sess.run(self.init_op)
Example no. 3
    def __init__(self, rows, cols, channels, data, input_size=128):
        self.outputdir = self.getOutput()
        # Input shape
        self.img_rows = rows
        self.img_cols = cols
        self.channels = channels
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        ## Encoded dimensions (Autoencoder)
        self.latent_dim = 100
        ## Loaded data
        self.data = data
        ## What is input size?
        self.input_size = input_size
        ## Define optimizer
        optimizer = Adam(0.0002, 0.5)
        self.setupLogs()

        # Build and compile the discriminator
        dc = Discriminator(self.img_shape)
        self.discriminator = dc.build_discriminator(input_size=32)
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        gn = Generator(self.img_shape, self.input_size, self.latent_dim,
                       self.channels)
        self.generator = gn.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim, ))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
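A hypothetical single training step for the class above (the original snippet stops at the constructor); it assumes an instance named `gan` whose `data` attribute is a NumPy array of real images shaped (N, rows, cols, channels).

import numpy as np

# sample a real batch and a generated batch
batch_size = 32
idx = np.random.randint(0, gan.data.shape[0], batch_size)
real_imgs = gan.data[idx]
noise = np.random.normal(0, 1, (batch_size, gan.latent_dim))
fake_imgs = gan.generator.predict(noise)

# discriminator: real batch labeled 1, generated batch labeled 0
d_loss_real = gan.discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1)))
d_loss_fake = gan.discriminator.train_on_batch(fake_imgs, np.zeros((batch_size, 1)))

# generator: trained through the frozen discriminator with flipped labels
g_loss = gan.combined.train_on_batch(noise, np.ones((batch_size, 1)))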
Example no. 4
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0

    #Build data loaders for the generator, testing and the discriminator
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    #Build generator and its rollout
    generator = Generator(config=config_gen)
    generator.build()
    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    target_params = cPickle.load(StrToBytes(open('save/target_params.pkl')),
                                 encoding='bytes')
    target_lstm = TARGET_LSTM(config=config_gen,
                              params=target_params)  # The oracle model

    #Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    #Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]  #Using the name 'teller' here prevents a name collision with the target LSTM's variables
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))

    #Initialize all variables
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())

    #Initialize data loader of generator
    # generate_samples(sess, target_lstm, config_train.batch_size, config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    #Start pretraining
    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training generator...')
    log.write('pre-training...\n')
    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _, g_loss = sess.run([gen_pre_update, generator.pretrained_loss],
                                 feed_dict={generator.input_seqs_pre: batch,
                                            generator.input_seqs_mask: np.ones_like(batch)})
        if epoch % config_train.test_per_epoch == 0:
            # generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    #Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    #Initialize global variables of optimizer for adversarial training
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    #Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshape)

            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            #calculate the reward at the specific step t by rollout
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)
                rollout_list_stack = np.vstack(
                    rollout_list)  # shape: (batch_size * rollout_step, sequence_length)
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv],
                                   feed_dict={generator.input_seqs_adv: samples,
                                              generator.rewards: rewards})
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)

        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)

            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
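For reference, a main(unused_argv) signature like the one above is normally dispatched through TF1's flags-based entry point; a minimal sketch (the original file's footer is not shown in this snippet):

# assumed module footer; tf.app.run() parses flags and then calls main()
if __name__ == '__main__':
    tf.app.run()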
Example no. 5
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()
    np.random.seed(config_train.seed)
    assert config_train.start_token == 0

    #Build data loaders for the generator, testing and the discriminator
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    #Build generator and its rollout
    generator = Generator(config=config_gen)
    # Build the three neural networks
    generator.build()
    #  Fast rollout network: predicts the rest of a sequence before it is fully generated, used to compute the reward
    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    target_params = cPickle.load(open('save/target_params.pkl'))
    target_lstm = TARGET_LSTM(config=config_gen,
                              params=target_params)  # The oracle model

    #Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()

    #Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(
        config_train.gen_learning_rate)
    # Collect all 'teller' variables; 'teller' lives in the generator and rollout networks
    var_pretrained = [
        v for v in tf.trainable_variables() if 'teller' in v.name
    ]  #Using the name 'teller' here prevents a name collision with the target LSTM's variables
    # zip combines the two iterators into tuples
    gradients, variables = zip(*pretrained_optimizer.compute_gradients(
        generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(
        zip(gradients, variables))

    #Initialize all variables
    sess = tf.Session(config=config_hardware)
    sess.run(tf.global_variables_initializer())

    #Initialize data loader of generator (see utils.py)
    #   The target_lstm network generates the real data and writes it to config_train.positive_file
    generate_samples(sess, target_lstm, config_train.batch_size,
                     config_train.generated_num, config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    #Start pretraining
    log = open('save/experiment-log.txt', 'w')
    print 'Start pre-training generator...'
    log.write('pre-training...\n')
    for epoch in xrange(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in xrange(gen_data_loader.num_batch):
            #Load the data generated by target_lstm (via generate_samples above): real samples used to pretrain the generator
            batch = gen_data_loader.next_batch()
            #Train the generator on the real data (supervised learning); the batch itself provides the targets
            _, g_loss = sess.run([gen_pre_update, generator.pretrained_loss],
                                 feed_dict={generator.input_seqs_pre: batch,
                                            generator.input_seqs_mask: np.ones_like(batch)})
        if epoch % config_train.test_per_epoch == 0:
            #  Measure how similar the generator's samples are to the real data
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            #Evaluate generation quality
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)

    print 'Start pre-training discriminator...'
    for t in range(config_train.dis_update_time_pre):
        print "Times: " + str(t)
        #   Fake data from the generator plus real data from target_lstm, used for training
        generate_samples(sess, generator, config_train.batch_size,
                         config_train.generated_num,
                         config_train.negative_file)
        #  Mix the real and fake data
        dis_data_loader.load_train_data(config_train.positive_file,
                                        config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob
                }
                #Minimize the cross-entropy; this mainly trains the scoring network that provides the reward for the generator
                _ = sess.run(discriminator.train_op, feed)

    #Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(
        generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    #Initialize global variables of optimizer for adversarial training
    uninitialized_var = [
        e for e in tf.global_variables() if e not in tf.trainable_variables()
    ]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    #Start adversarial training
    for total_batch in xrange(config_train.total_batch):
        for iter_gen in xrange(config_train.gen_update_time):

            #  Sample from the generator; the LSTM produces the sequences
            samples = sess.run(generator.sample_word_list_reshape)

            feed = {"pred_seq_rollout:0": samples}
            reward_rollout = []
            #calculate the reward at the specific step t by rolling out with the rollout network
            for iter_roll in xrange(config_train.rollout_num):

                # Feed the words sampled by the generator to the rollout network. (Open question: `samples` is already a complete sequence here, unlike in the paper, so why roll out at all?)
                rollout_list = sess.run(rollout_gen.sample_rollout_step,
                                        feed_dict=feed)

                rollout_list_stack = np.vstack(
                    rollout_list)  # shape: (batch_size * rollout_step, sequence_length)
                # Monte Carlo rollout into full sequences; the reward is computed via the Bellman equation
                reward_rollout_seq = sess.run(
                    discriminator.ypred_for_auc,
                    feed_dict={
                        discriminator.input_x: rollout_list_stack,
                        discriminator.dropout_keep_prob: 1.0
                    })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,
                                           feed_dict={
                                               discriminator.input_x: samples,
                                               discriminator.dropout_keep_prob:
                                               1.0
                                           })
                reward_allseq = np.concatenate(
                    (reward_rollout_seq, reward_last_tok), axis=0)[:, 1]
                reward_tmp = []
                for r in xrange(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(
                        r,
                        config_gen.gen_batch_size * config_gen.sequence_length,
                        config_gen.gen_batch_size)])
                reward_rollout.append(np.array(reward_tmp))
            #Compute the rewards
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            # Use the rewards to guide the generator's gradient update
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv],
                                   feed_dict={generator.input_seqs_adv: samples,
                                              generator.rewards: rewards})
        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            #After adversarial training, generate samples again and compare them against the oracle (target_lstm, i.e. the real data)
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            #Defined in util.py
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess, generator, config_train.batch_size,
                             config_train.generated_num,
                             config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,
                                            config_train.negative_file)

            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: config_dis.dis_dropout_keep_prob
                    }
                    #Train this scoring network
                    _ = sess.run(discriminator.train_op, feed)
    log.close()
Example no. 6
def main(unused_argv):
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()

    np.random.seed(config_train.seed)

    assert config_train.start_token == 0
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    generator = Generator(config=config_gen)
    generator.build()

    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    target_params = pickle.load(open('save/target_params.pkl','rb'),encoding='iso-8859-1')
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params) # The oracle model


    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()


    # Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    var_pretrained = [v for v in tf.trainable_variables() if 'teller' in v.name]
    gradients, variables = zip(
        *pretrained_optimizer.compute_gradients(generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(zip(gradients, variables))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    generate_samples(sess,target_lstm,config_train.batch_size,config_train.generated_num,config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    log = open('save/experiment-log.txt','w')
    print('Start pre-training generator....')

    log.write('pre-training...\n')

    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _,g_loss = sess.run([gen_pre_update,generator.pretrained_loss],feed_dict={generator.input_seqs_pre:batch,
                                                                                      generator.input_seqs_mask:np.ones_like(batch)})

        if epoch % config_train.test_per_epoch == 0:
            #Evaluation: have the generator produce a batch of sequences,
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.eval_file)
            # build a data loader over these generated sequences,
            likelihood_data_loader.create_batches(config_train.eval_file)
            # and use the oracle to compute the cross-entropy (NLL) loss
            test_loss = target_loss(sess,target_lstm,likelihood_data_loader)
            # Print and write to the log
            print('pre-train ',epoch, ' test_loss ',test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)


    print('Start pre-training discriminator...')
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)
        for _ in range(config_train.dis_update_epoch_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch,y_batch = dis_data_loader.next_batch()
                feed_dict = {
                    discriminator.input_x : x_batch,
                    discriminator.input_y : y_batch,
                    discriminator.dropout_keep_prob : config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op,feed_dict)



    # Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize global variables of optimizer for adversarial training
    uninitialized_var = [e for e in tf.global_variables() if e not in tf.trainable_variables()]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    # Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshape)

            feed = {'pred_seq_rollout:0':samples}
            reward_rollout = []
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,feed_dict=feed)
                # np.vstack stacks the arrays vertically (row-wise)
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:rollout_list_stack,discriminator.dropout_keep_prob:1.0
                })
                reward_last_tok = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:samples,discriminator.dropout_keep_prob:1.0
                })
                reward_allseq = np.concatenate((reward_rollout_seq,reward_last_tok),axis=0)[:,1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(r,config_gen.gen_batch_size * config_gen.sequence_length,config_gen.gen_batch_size)])

                reward_rollout.append(np.array(reward_tmp))
            rewards = np.sum(reward_rollout, axis=0) / config_train.rollout_num
            _, gen_loss = sess.run([train_adv_update, generator.gen_loss_adv],
                                   feed_dict={generator.input_seqs_adv: samples,
                                              generator.rewards: rewards})


        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print ('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)


        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)

            for _ in range(config_train.dis_update_epoch_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch,y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x:x_batch,
                        discriminator.input_y:y_batch,
                        discriminator.dropout_keep_prob:config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op,feed)

    log.close()