def disc_pre_train(text_data):
    """Pre-train the hierarchical discriminator before adversarial training.

    Builds a discriminator training set over the whole corpus
    (bucket_id=-1, no pre-built generator train_set) and hands it to the
    hierarchical discriminator trainer.

    Args:
        text_data: corpus object exposing getVocabularySize(); also
            consumed by gens.create_disc_train_set.
    """
    disc_train_set = gens.create_disc_train_set(
        gen_config, text_data, -1, None, gen_config.disc_data_batch_num)
    vocab_size = text_data.getVocabularySize()
    h_disc.hier_train(disc_config, evl_config, vocab_size, disc_train_set)
def al_train(text_data):
    """Adversarially train the generator and discriminator (SeqGAN-style).

    Each step samples a bucket proportionally to its size, then:
      1. runs D_STEPS discriminator updates on real (query, answer) pairs
         vs. generator samples, and
      2. runs G_STEPS generator updates: a policy-gradient step that scales
         the loss by the discriminator's reward on Monte-Carlo samples,
         followed by a plain teacher-forcing step on the real data.
    Losses/rewards are averaged over `gen_config.steps_per_checkpoint`
    steps, written to TensorBoard, and models are checkpointed every
    `steps_per_checkpoint * 4` steps.

    Args:
        text_data: corpus object exposing getVocabularySize(); also
            consumed by gens.create_train_set / gens.get_batch.

    Note: this is an infinite training loop (`while True`); it is expected
    to be terminated externally.
    """
    with tf.Session() as sess:
        train_set = gens.create_train_set(gen_config, text_data)

        # Log per-bucket and total QA-pair counts.
        total_qa_size = 0
        for i, bucket in enumerate(train_set):  # `bucket`: avoid shadowing builtin `set`
            length = len(bucket)
            print("Generator train_set_{} len: {}".format(i, length))
            total_qa_size += length
        print("Generator train_set total size is {} QA".format(total_qa_size))

        # Cumulative distribution over buckets, used below to pick a bucket
        # with probability proportional to its size.
        train_bucket_sizes = [
            len(train_set[b]) for b in range(len(gen_config.buckets))
        ]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in range(len(train_bucket_sizes))
        ]
        vocab_size = text_data.getVocabularySize()
        disc_model = h_disc.create_model(sess, disc_config, vocab_size,
                                         disc_config.name_model)
        gen_model = gens.create_model(sess,
                                      gen_config,
                                      vocab_size,
                                      forward_only=False,
                                      name_scope=gen_config.name_model)

        current_step = 0
        # Running averages accumulated between checkpoints.
        step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
        gen_loss_summary = tf.Summary()
        disc_loss_summary = tf.Summary()

        gen_writer = tf.summary.FileWriter(gen_config.tensorboard_dir,
                                           sess.graph)
        disc_writer = tf.summary.FileWriter(disc_config.tensorboard_dir,
                                            sess.graph)

        while True:
            current_step += 1
            # Sample a bucket id from the cumulative size distribution.
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in range(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])
            start_time = time.time()
            print(
                "==================Update Discriminator: %d=================="
                % current_step)
            for i in range(D_STEPS):
                print(
                    "=============It's the %d time update Discriminator in current step============="
                    % (i + 1))

                # 1. Sample (X,Y) from real data and sample ^Y from G(*|X)
                query_set, answer_set, gen_set = gens.create_disc_train_set(
                    gen_config, text_data, bucket_id, train_set, 1, sess,
                    gen_model)

                b_query, b_answer, b_gen = query_set[bucket_id], answer_set[
                    bucket_id], gen_set[bucket_id]

                # Batch of real answers (label 1) and generated answers
                # (label 0) for the sampled bucket.
                train_query, train_answer, train_labels = h_disc.hier_get_batch(
                    disc_config,
                    len(b_query) - 1, b_query, b_answer, b_gen)
                # Transpose to time-major layout expected by the model.
                train_query = np.transpose(train_query)
                train_answer = np.transpose(train_answer)

                _, disc_step_loss = disc_step(sess,
                                              bucket_id,
                                              disc_model,
                                              train_query,
                                              train_answer,
                                              train_labels,
                                              forward_only=False)
                # Average over both the D_STEPS inner updates and the
                # checkpoint interval.
                disc_loss += disc_step_loss / (
                    D_STEPS * disc_config.steps_per_checkpoint)
                if i == D_STEPS - 1:
                    print("disc_step_loss: ", disc_step_loss)

            print("==================Update Generator: %d==================" %
                  current_step)
            for j in range(G_STEPS):
                print(
                    "=============It's the %d time update Generator in current step============="
                    % (j + 1))
                encoder_inputs, decoder_inputs, target_weights,\
                    source_inputs, source_outputs = gens.get_batch(gen_config, train_set, bucket_id,
                                                                   gen_config.batch_size, text_data)

                # Negative (generated) decoder inputs for the policy-gradient
                # update below.
                decoder_inputs_negative = get_negative_decoder_inputs(
                    sess, gen_model, encoder_inputs, decoder_inputs,
                    target_weights, bucket_id)
                decoder_inputs_negative = np.transpose(decoder_inputs_negative)

                # Build a mixed batch: real pairs labeled 1, then beam_size
                # rounds of Monte-Carlo samples labeled 0.
                train_query, train_answer, train_labels = [], [], []
                for query, answer in zip(source_inputs, source_outputs):
                    train_query.append(query)
                    train_answer.append(answer)
                    train_labels.append(1)
                for _ in range(gen_config.beam_size):
                    gen_set = get_negative_decoder_inputs(sess,
                                                          gen_model,
                                                          encoder_inputs,
                                                          decoder_inputs,
                                                          target_weights,
                                                          bucket_id,
                                                          mc_search=True)
                    for i, output in enumerate(gen_set):
                        # train_query[i] points at the i-th REAL query
                        # appended above; valid only while len(gen_set) does
                        # not exceed the real batch size — TODO confirm.
                        train_query.append(train_query[i])
                        train_answer.append(output)
                        train_labels.append(0)

                train_query = np.transpose(train_query)
                train_answer = np.transpose(train_answer)

                # Forward-only discriminator pass; its output is used as the
                # reward signal (presumably the probability the samples are
                # judged real — verify against disc_step).
                reward, _ = disc_step(sess,
                                      bucket_id,
                                      disc_model,
                                      train_query,
                                      train_answer,
                                      train_labels,
                                      forward_only=True)
                batch_reward += reward / gen_config.steps_per_checkpoint
                print("step_reward: ", reward)

                # Policy-gradient update: loss on generated sequences scaled
                # by the reward (up_reward=True).
                gan_adjusted_loss, gen_step_loss, _ = gen_model.step(
                    sess,
                    encoder_inputs,
                    decoder_inputs_negative,
                    target_weights,
                    bucket_id,
                    forward_only=False,
                    reward=reward,
                    up_reward=True,
                    debug=True)
                gen_loss += gen_step_loss / gen_config.steps_per_checkpoint

                print("gen_step_loss: ", gen_step_loss)
                print("gen_step_adjusted_loss: ", gan_adjusted_loss)

                # Teacher-forcing update on the real data to keep the
                # generator anchored to the corpus.
                t_adjusted_loss, t_step_loss, _ = gen_model.step(
                    sess,
                    encoder_inputs,
                    decoder_inputs,
                    target_weights,
                    bucket_id,
                    forward_only=False)
                t_loss += t_step_loss / (G_STEPS *
                                         gen_config.steps_per_checkpoint)

                print("t_step_loss: ", t_step_loss)
                print("t_adjusted_loss", t_adjusted_loss)

            if current_step % gen_config.steps_per_checkpoint == 0:

                step_time += (time.time() -
                              start_time) / gen_config.steps_per_checkpoint

                print(
                    "current_steps: %d, step time: %.4f, disc_loss: %.3f, gen_loss: %.3f, t_loss: %.3f, reward: %.3f "
                    % (current_step, step_time, disc_loss, gen_loss, t_loss,
                       batch_reward))

                # Write averaged discriminator loss to TensorBoard.
                disc_loss_value = disc_loss_summary.value.add()
                disc_loss_value.tag = disc_config.name_loss
                disc_loss_value.simple_value = float(disc_loss)
                disc_writer.add_summary(disc_loss_summary,
                                        int(sess.run(disc_model.global_step)))

                # Write averaged generator losses and reward to TensorBoard.
                gen_global_steps = sess.run(gen_model.global_step)
                gen_loss_value = gen_loss_summary.value.add()
                gen_loss_value.tag = gen_config.name_loss
                gen_loss_value.simple_value = float(gen_loss)
                t_loss_value = gen_loss_summary.value.add()
                t_loss_value.tag = gen_config.teacher_loss
                t_loss_value.simple_value = float(t_loss)
                batch_reward_value = gen_loss_summary.value.add()
                batch_reward_value.tag = gen_config.reward_name
                batch_reward_value.simple_value = float(batch_reward)
                gen_writer.add_summary(gen_loss_summary, int(gen_global_steps))

                # Persist both models every 4 checkpoint intervals.
                if current_step % (gen_config.steps_per_checkpoint * 4) == 0:
                    print("current_steps: %d, save disc model" % current_step)
                    disc_ckpt_dir = os.path.abspath(
                        os.path.join(disc_config.train_dir, "checkpoints"))
                    if not os.path.exists(disc_ckpt_dir):
                        os.makedirs(disc_ckpt_dir)
                    disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
                    disc_model.saver.save(sess,
                                          disc_model_path,
                                          global_step=disc_model.global_step)

                    print("current_steps: %d, save gen model" % current_step)
                    gen_ckpt_dir = os.path.abspath(
                        os.path.join(gen_config.train_dir, "checkpoints"))
                    if not os.path.exists(gen_ckpt_dir):
                        os.makedirs(gen_ckpt_dir)
                    gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
                    gen_model.saver.save(sess,
                                         gen_model_path,
                                         global_step=gen_model.global_step)

                # Reset the running averages for the next interval.
                step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
                sys.stdout.flush()