def al_train(text_data):
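    """Adversarial training loop for the generator (G) and discriminator (D).

    Each outer step runs D_STEPS discriminator updates on real (X, Y) pairs
    versus generated (X, ^Y) pairs, followed by G_STEPS generator updates
    that combine a policy-gradient step (scaled by D's reward) with a
    teacher-forcing maximum-likelihood step on the real pair.
    """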
    with tf.Session() as sess:
        train_set = gens.create_train_set(gen_config, text_data)

        total_qa_size = 0
        for i, bucket in enumerate(train_set):
            length = len(bucket)
            print("Generator train_set_{} len: {}".format(i, length))
            total_qa_size += length
        print("Generator train_set total size is {} QA".format(total_qa_size))

        train_bucket_sizes = [
            len(train_set[b]) for b in range(len(gen_config.buckets))
        ]
        train_total_size = float(sum(train_bucket_sizes))
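        # Cumulative fraction of examples in buckets [0..i]; used below to
        # sample a bucket in proportion to its size.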
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in range(len(train_bucket_sizes))
        ]
        vocab_size = text_data.getVocabularySize()
        disc_model = h_disc.create_model(sess, disc_config, vocab_size,
                                         disc_config.name_model)
        gen_model = gens.create_model(sess,
                                      gen_config,
                                      vocab_size,
                                      forward_only=False,
                                      name_scope=gen_config.name_model)

        current_step = 0
        step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
        gen_loss_summary = tf.Summary()
        disc_loss_summary = tf.Summary()

        gen_writer = tf.summary.FileWriter(gen_config.tensorboard_dir,
                                           sess.graph)
        disc_writer = tf.summary.FileWriter(disc_config.tensorboard_dir,
                                            sess.graph)

        while True:
            current_step += 1
            random_number_01 = np.random.random_sample()
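            # Pick a bucket_id with probability proportional to its bucket size.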
            bucket_id = min([
                i for i in range(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])
            start_time = time.time()
            print(
                "==================Update Discriminator: %d=================="
                % current_step)
            for i in range(D_STEPS):
                print(
                    "=============Discriminator update %d of %d in current step============="
                    % (i + 1, D_STEPS))

                # 1. Sample (X,Y) from real data and sample ^Y from G(*|X)
                query_set, answer_set, gen_set = gens.create_disc_train_set(
                    gen_config, text_data, bucket_id, train_set, 1, sess,
                    gen_model)

                b_query = query_set[bucket_id]
                b_answer = answer_set[bucket_id]
                b_gen = gen_set[bucket_id]

                train_query, train_answer, train_labels = h_disc.hier_get_batch(
                    disc_config,
                    len(b_query) - 1, b_query, b_answer, b_gen)
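                # Transpose from batch-major to time-major order for the discriminator.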
                train_query = np.transpose(train_query)
                train_answer = np.transpose(train_answer)

                _, disc_step_loss = disc_step(sess,
                                              bucket_id,
                                              disc_model,
                                              train_query,
                                              train_answer,
                                              train_labels,
                                              forward_only=False)
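                # Average D's loss over the D_STEPS updates and the checkpoint window.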
                disc_loss += disc_step_loss / (
                    D_STEPS * disc_config.steps_per_checkpoint)
                if i == D_STEPS - 1:
                    print("disc_step_loss: ", disc_step_loss)

            print("==================Update Generator: %d==================" %
                  current_step)
            for j in range(G_STEPS):
                print(
                    "=============Generator update %d of %d in current step============="
                    % (j + 1, G_STEPS))
                (encoder_inputs, decoder_inputs, target_weights,
                 source_inputs, source_outputs) = gens.get_batch(
                     gen_config, train_set, bucket_id, gen_config.batch_size,
                     text_data)

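                # Sample a negative response ^Y ~ G(*|X) for the policy-gradient step.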
                decoder_inputs_negative = get_negative_decoder_inputs(
                    sess, gen_model, encoder_inputs, decoder_inputs,
                    target_weights, bucket_id)
                decoder_inputs_negative = np.transpose(decoder_inputs_negative)

                train_query, train_answer, train_labels = [], [], []
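                # Real (X, Y) pairs are labeled 1; Monte Carlo samples from G
                # (appended below) are labeled 0.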
                for query, answer in zip(source_inputs, source_outputs):
                    train_query.append(query)
                    train_answer.append(answer)
                    train_labels.append(1)
                for _ in range(gen_config.beam_size):
                    negative_set = get_negative_decoder_inputs(sess,
                                                               gen_model,
                                                               encoder_inputs,
                                                               decoder_inputs,
                                                               target_weights,
                                                               bucket_id,
                                                               mc_search=True)
                    for idx, output in enumerate(negative_set):
                        # Re-use the query of the idx-th real pair for the negative example.
                        train_query.append(train_query[idx])
                        train_answer.append(output)
                        train_labels.append(0)

                train_query = np.transpose(train_query)
                train_answer = np.transpose(train_answer)

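                # Evaluate D on the mixed batch; its score is the reward for G.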
                reward, _ = disc_step(sess,
                                      bucket_id,
                                      disc_model,
                                      train_query,
                                      train_answer,
                                      train_labels,
                                      forward_only=True)
                batch_reward += reward / gen_config.steps_per_checkpoint
                print("step_reward: ", reward)

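                # Policy-gradient update on (X, ^Y): up_reward=True scales the
                # generator loss by the discriminator reward.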
                gan_adjusted_loss, gen_step_loss, _ = gen_model.step(
                    sess,
                    encoder_inputs,
                    decoder_inputs_negative,
                    target_weights,
                    bucket_id,
                    forward_only=False,
                    reward=reward,
                    up_reward=True,
                    debug=True)
                gen_loss += gen_step_loss / gen_config.steps_per_checkpoint

                print("gen_step_loss: ", gen_step_loss)
                print("gen_step_adjusted_loss: ", gan_adjusted_loss)

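                # Teacher forcing: a standard maximum-likelihood step on the real (X, Y).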
                t_adjusted_loss, t_step_loss, _ = gen_model.step(
                    sess,
                    encoder_inputs,
                    decoder_inputs,
                    target_weights,
                    bucket_id,
                    forward_only=False)
                t_loss += t_step_loss / (G_STEPS *
                                         gen_config.steps_per_checkpoint)

                print("t_step_loss: ", t_step_loss)
                print("t_adjusted_loss: ", t_adjusted_loss)

            if current_step % gen_config.steps_per_checkpoint == 0:

                step_time += (time.time() -
                              start_time) / gen_config.steps_per_checkpoint

                print(
                    "current_steps: %d, step time: %.4f, disc_loss: %.3f, gen_loss: %.3f, t_loss: %.3f, reward: %.3f"
                    % (current_step, step_time, disc_loss, gen_loss, t_loss,
                       batch_reward))

                disc_loss_value = disc_loss_summary.value.add()
                disc_loss_value.tag = disc_config.name_loss
                disc_loss_value.simple_value = float(disc_loss)
                disc_writer.add_summary(disc_loss_summary,
                                        int(sess.run(disc_model.global_step)))

                gen_global_steps = sess.run(gen_model.global_step)
                gen_loss_value = gen_loss_summary.value.add()
                gen_loss_value.tag = gen_config.name_loss
                gen_loss_value.simple_value = float(gen_loss)
                t_loss_value = gen_loss_summary.value.add()
                t_loss_value.tag = gen_config.teacher_loss
                t_loss_value.simple_value = float(t_loss)
                batch_reward_value = gen_loss_summary.value.add()
                batch_reward_value.tag = gen_config.reward_name
                batch_reward_value.simple_value = float(batch_reward)
                gen_writer.add_summary(gen_loss_summary, int(gen_global_steps))

                if current_step % (gen_config.steps_per_checkpoint * 4) == 0:
                    print("current_steps: %d, save disc model" % current_step)
                    disc_ckpt_dir = os.path.abspath(
                        os.path.join(disc_config.train_dir, "checkpoints"))
                    if not os.path.exists(disc_ckpt_dir):
                        os.makedirs(disc_ckpt_dir)
                    disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
                    disc_model.saver.save(sess,
                                          disc_model_path,
                                          global_step=disc_model.global_step)

                    print("current_steps: %d, save gen model" % current_step)
                    gen_ckpt_dir = os.path.abspath(
                        os.path.join(gen_config.train_dir, "checkpoints"))
                    if not os.path.exists(gen_ckpt_dir):
                        os.makedirs(gen_ckpt_dir)
                    gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
                    gen_model.saver.save(sess,
                                         gen_model_path,
                                         global_step=gen_model.global_step)

                step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
                sys.stdout.flush()
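

# NOTE: disc_step is called throughout this file but not defined in this
# section. The sketch below is illustrative only, inferred from the call
# sites: with forward_only=False it takes one gradient step on D and returns
# (None, loss); with forward_only=True it evaluates D and returns
# (reward, None), where the reward is D's mean probability that the batch's
# answers are human. Every attribute accessed on disc_model here is an
# assumption, not the project's actual API.
def disc_step(sess, bucket_id, disc_model, train_query, train_answer,
              train_labels, forward_only=False):
    # bucket_id is accepted for interface parity; a bucket-aware discriminator
    # would use it to select the matching placeholders.
    feed_dict = {
        disc_model.query_input: train_query,    # hypothetical placeholder
        disc_model.answer_input: train_answer,  # hypothetical placeholder
        disc_model.labels: train_labels,        # hypothetical placeholder
    }
    if forward_only:
        # Evaluation only: mean P(human | X, Y) over the batch is the reward.
        logits = sess.run(disc_model.logits, feed_dict)
        probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
        return float(np.mean(probs[:, 1])), None
    # Training: one gradient step on the cross-entropy loss.
    _, step_loss = sess.run([disc_model.train_op, disc_model.loss], feed_dict)
    return None, step_loss
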
def al_train():
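    """Variant of al_train: one D update and one G update per outer step.

    Discriminator batches come from disc_train_data, built without Monte
    Carlo search for the D update and with it for the reward given to G.
    Losses and rewards are logged to TensorBoard at each checkpoint.
    """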
    with tf.Session() as sess:

        vocab, rev_vocab, dev_set, train_set = gens.prepare_data(gen_config)
        for bucket in train_set:
            print("al train len: ", len(bucket))

        train_bucket_sizes = [
            len(train_set[b]) for b in range(len(gen_config.buckets))
        ]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in range(len(train_bucket_sizes))
        ]

        disc_model = h_disc.create_model(sess, disc_config,
                                         disc_config.name_model)
        gen_model = gens.create_model(sess,
                                      gen_config,
                                      forward_only=False,
                                      name_scope=gen_config.name_model)

        current_step = 0
        step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
        gen_loss_summary = tf.Summary()
        disc_loss_summary = tf.Summary()

        gen_writer = tf.summary.FileWriter(gen_config.tensorboard_dir,
                                           sess.graph)
        disc_writer = tf.summary.FileWriter(disc_config.tensorboard_dir,
                                            sess.graph)

        while True:
            current_step += 1
            start_time = time.time()
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in range(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])
            # disc_config.max_len = gen_config.buckets[bucket_id][0] + gen_config.buckets[bucket_id][1]

            print(
                "==================Update Discriminator: %d====================="
                % current_step)
            # 1.Sample (X,Y) from real disc_data
            # print("bucket_id: %d" %bucket_id)
            encoder_inputs, decoder_inputs, target_weights, source_inputs, source_outputs = gen_model.get_batch(
                train_set, bucket_id, gen_config.batch_size)

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X)
            train_query, train_answer, train_labels = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                encoder_inputs,
                decoder_inputs,
                target_weights,
                bucket_id,
                mc_search=False)
            print(
                "==============================mc_search: False==================================="
            )
            if current_step % 200 == 0:
                print("train_query: ", len(train_query))
                print("train_answer: ", len(train_answer))
                print("train_labels: ", len(train_labels))
                for i in range(len(train_query)):
                    print("label: ", train_labels[i])
                    print("train_answer_sentence: ", train_answer[i])
                    print(" ".join([
                        tf.compat.as_str(rev_vocab[output])
                        for output in train_answer[i]
                    ]))

            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)

            # 3.Update D using (X, Y ) as positive examples and(X, ^Y) as negative examples
            _, disc_step_loss = disc_step(sess,
                                          bucket_id,
                                          disc_model,
                                          train_query,
                                          train_answer,
                                          train_labels,
                                          forward_only=False)
            disc_loss += disc_step_loss / disc_config.steps_per_checkpoint
            # D is updated every step; G is updated below (sample outputs are printed every 200 steps).
            print(
                "==================Update Generator: %d========================="
                % current_step)
            # 1.Sample (X,Y) from real disc_data
            update_gen_data = gen_model.get_batch(train_set, bucket_id,
                                                  gen_config.batch_size)
            encoder, decoder, weights, source_inputs, source_outputs = update_gen_data

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X) with Monte Carlo search
            train_query, train_answer, train_labels = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                encoder,
                decoder,
                weights,
                bucket_id,
                mc_search=True)

            print(
                "=============================mc_search: True===================================="
            )
            if current_step % 200 == 0:
                for i in range(len(train_query)):
                    print("label: ", train_labels[i])
                    print(" ".join([
                        tf.compat.as_str(rev_vocab[output])
                        for output in train_answer[i]
                    ]))

            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)

            # 3.Compute Reward r for (X, ^Y ) using D.---based on Monte Carlo search
            reward, _ = disc_step(sess,
                                  bucket_id,
                                  disc_model,
                                  train_query,
                                  train_answer,
                                  train_labels,
                                  forward_only=True)
            batch_reward += reward / gen_config.steps_per_checkpoint
            print("step_reward: ", reward)

            # 4. Update G on (X, ^Y) using reward r (policy-gradient update of G)
            gan_adjusted_loss, gen_step_loss, _ = gen_model.step(
                sess,
                encoder,
                decoder,
                weights,
                bucket_id,
                forward_only=False,
                reward=reward,
                up_reward=True,
                debug=True)
            gen_loss += gen_step_loss / gen_config.steps_per_checkpoint

            print("gen_step_loss: ", gen_step_loss)
            print("gen_step_adjusted_loss: ", gan_adjusted_loss)

            # 5. Teacher forcing: update G on (X, Y) (maximum-likelihood update of G)
            t_adjusted_loss, t_step_loss, _ = gen_model.step(
                sess, encoder, decoder, weights, bucket_id, forward_only=False)
            t_loss += t_step_loss / gen_config.steps_per_checkpoint

            print("t_step_loss: ", t_step_loss)
            print("t_adjusted_loss: ", t_adjusted_loss)

            if current_step % gen_config.steps_per_checkpoint == 0:

                step_time += (time.time() -
                              start_time) / gen_config.steps_per_checkpoint

                print(
                    "current_steps: %d, step time: %.4f, disc_loss: %.3f, gen_loss: %.3f, t_loss: %.3f, reward: %.3f"
                    % (current_step, step_time, disc_loss, gen_loss, t_loss,
                       batch_reward))

                disc_loss_value = disc_loss_summary.value.add()
                disc_loss_value.tag = disc_config.name_loss
                disc_loss_value.simple_value = float(disc_loss)
                disc_writer.add_summary(disc_loss_summary,
                                        int(sess.run(disc_model.global_step)))

                gen_global_steps = sess.run(gen_model.global_step)
                gen_loss_value = gen_loss_summary.value.add()
                gen_loss_value.tag = gen_config.name_loss
                gen_loss_value.simple_value = float(gen_loss)
                t_loss_value = gen_loss_summary.value.add()
                t_loss_value.tag = gen_config.teacher_loss
                t_loss_value.simple_value = float(t_loss)
                batch_reward_value = gen_loss_summary.value.add()
                batch_reward_value.tag = gen_config.reward_name
                batch_reward_value.simple_value = float(batch_reward)
                gen_writer.add_summary(gen_loss_summary, int(gen_global_steps))

                if current_step % (gen_config.steps_per_checkpoint * 2) == 0:
                    print("current_steps: %d, save disc model" % current_step)
                    disc_ckpt_dir = os.path.abspath(
                        os.path.join(disc_config.train_dir, "checkpoints"))
                    if not os.path.exists(disc_ckpt_dir):
                        os.makedirs(disc_ckpt_dir)
                    disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
                    disc_model.saver.save(sess,
                                          disc_model_path,
                                          global_step=disc_model.global_step)

                    print("current_steps: %d, save gen model" % current_step)
                    gen_ckpt_dir = os.path.abspath(
                        os.path.join(gen_config.train_dir, "checkpoints"))
                    if not os.path.exists(gen_ckpt_dir):
                        os.makedirs(gen_ckpt_dir)
                    gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
                    gen_model.saver.save(sess,
                                         gen_model_path,
                                         global_step=gen_model.global_step)

                step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
                sys.stdout.flush()
def al_train():
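    """Minimal variant of al_train without loss averaging or TensorBoard logging.

    Alternates a single discriminator update with a policy-gradient update
    and a teacher-forcing update of the generator, saving checkpoints
    periodically.
    """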
    with tf.Session() as sess:
        current_step = 1
        disc_model = h_disc.create_model(sess, disc_config)
        gen_model = gens.create_model(sess, gen_config)
        vocab, rev_vocab, dev_set, train_set = gens.prepare_data(gen_config)
        for bucket in train_set:
            print("train len: ", len(bucket))

        train_bucket_sizes = [
            len(train_set[b]) for b in range(len(gen_config.buckets))
        ]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in range(len(train_bucket_sizes))
        ]

        while True:
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in range(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])
            disc_config.max_len = gen_config.buckets[bucket_id][
                0] + gen_config.buckets[bucket_id][1]
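            # The discriminator's max input length spans the encoder and decoder bucket sizes.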
            print(
                "===========================Update Discriminator================================"
            )
            # 1.Sample (X,Y) from real disc_data
            print("bucket_id: %d" % bucket_id)

            encoder_inputs, decoder_inputs, target_weights, source_inputs, source_outputs = gen_model.get_batch(
                train_set, bucket_id, gen_config.batch_size)
            print("source_inputs: ", len(source_inputs))
            print("source_outputs: ", len(source_outputs))
            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X)
            train_query, train_answer, train_labels = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                encoder_inputs,
                decoder_inputs,
                target_weights,
                bucket_id,
                mc_search=False)
            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)
            # 3.Update D using (X, Y ) as positive examples and(X, ^Y) as negative examples
            disc_step(sess,
                      bucket_id,
                      disc_model,
                      train_query,
                      train_answer,
                      train_labels,
                      forward_only=False)

            print(
                "===============================Update Generator================================"
            )
            # 1.Sample (X,Y) from real disc_data
            update_gen_data = gen_model.get_batch(train_set, bucket_id,
                                                  gen_config.batch_size)
            encoder, decoder, weights, source_inputs, source_outputs = update_gen_data

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X) with Monte Carlo search
            train_query, train_answer, train_labels = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                encoder,
                decoder,
                weights,
                bucket_id,
                mc_search=True)
            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)
            # 3.Compute Reward r for (X, ^Y ) using D.---based on Monte Carlo search
            # disc_step returns (reward, loss); only the reward is needed here.
            reward, _ = disc_step(sess,
                                  bucket_id,
                                  disc_model,
                                  train_query,
                                  train_answer,
                                  train_labels,
                                  forward_only=True)

            # 4.Update G on (X, ^Y ) using reward r
            _, loss, a = gen_model.step(sess,
                                        encoder,
                                        decoder,
                                        weights,
                                        bucket_id,
                                        forward_only=False,
                                        reward=reward,
                                        debug=True)
            print("up_reward: ", a)

            # 5.Teacher-Forcing: Update G on (X, Y )
            _, loss, a = gen_model.step(sess,
                                        encoder,
                                        decoder,
                                        weights,
                                        bucket_id,
                                        forward_only=False)
            print("loss: ", loss)
            print("normal: ", a)

            if current_step % gen_config.steps_per_checkpoint == 0:  # assuming the generator's checkpoint interval; the bare name was undefined
                print("save disc model")
                disc_ckpt_dir = os.path.abspath(
                    os.path.join(disc_config.data_dir, "checkpoints"))
                if not os.path.exists(disc_ckpt_dir):
                    os.makedirs(disc_ckpt_dir)
                disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
                disc_model.saver.save(sess,
                                      disc_model_path,
                                      global_step=disc_model.global_step)

                print("save gen model")
                gen_ckpt_dir = os.path.abspath(
                    os.path.join(gen_config.data_dir, "checkpoints"))
                if not os.path.exists(gen_ckpt_dir):
                    os.makedirs(gen_ckpt_dir)
                gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
                gen_model.saver.save(sess,
                                     gen_model_path,
                                     global_step=gen_model.global_step)
            current_step += 1