Code Example #1
def al_train():
    with tf.Session() as sess:

        vocab, rev_vocab, dev_set, train_set = gens.prepare_data(gen_config)
        for bucket in train_set:
            print("al train len: ", len(bucket))

        train_bucket_sizes = [
            len(train_set[b]) for b in xrange(len(gen_config.buckets))
        ]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in xrange(len(train_bucket_sizes))
        ]

        disc_model = h_disc.create_model(sess, disc_config,
                                         disc_config.name_model)
        gen_model = gens.create_model(sess,
                                      gen_config,
                                      forward_only=False,
                                      name_scope=gen_config.name_model)

        current_step = 0
        step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
        gen_loss_summary = tf.Summary()
        disc_loss_summary = tf.Summary()

        gen_writer = tf.summary.FileWriter(gen_config.tensorboard_dir,
                                           sess.graph)
        disc_writer = tf.summary.FileWriter(disc_config.tensorboard_dir,
                                            sess.graph)

        while True:
            current_step += 1
            start_time = time.time()
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in xrange(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])
            # disc_config.max_len = gen_config.buckets[bucket_id][0] + gen_config.buckets[bucket_id][1]

            print(
                "==================Update Discriminator: %d====================="
                % current_step)
            # 1.Sample (X,Y) from real disc_data
            # print("bucket_id: %d" %bucket_id)
            encoder_inputs, decoder_inputs, target_weights, source_inputs, source_outputs = gen_model.get_batch(
                train_set, bucket_id, gen_config.batch_size)

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X)
            train_query, train_answer, train_labels = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                encoder_inputs,
                decoder_inputs,
                target_weights,
                bucket_id,
                mc_search=False)
            print(
                "==============================mc_search: False==================================="
            )
            if current_step % 200 == 0:
                print("train_query: ", len(train_query))
                print("train_answer: ", len(train_answer))
                print("train_labels: ", len(train_labels))
                for i in xrange(len(train_query)):
                    print("lable: ", train_labels[i])
                    print("train_answer_sentence: ", train_answer[i])
                    print(" ".join([
                        tf.compat.as_str(rev_vocab[output])
                        for output in train_answer[i]
                    ]))

            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)

            # 3.Update D using (X, Y ) as positive examples and(X, ^Y) as negative examples
            _, disc_step_loss = disc_step(sess,
                                          bucket_id,
                                          disc_model,
                                          train_query,
                                          train_answer,
                                          train_labels,
                                          forward_only=False)
            disc_loss += disc_step_loss / disc_config.steps_per_checkpoint
            # update the D model every step, update the G model every 200 steps
            print(
                "==================Update Generator: %d========================="
                % current_step)
            # 1.Sample (X,Y) from real disc_data
            update_gen_data = gen_model.get_batch(train_set, bucket_id,
                                                  gen_config.batch_size)
            encoder, decoder, weights, source_inputs, source_outputs = update_gen_data

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X) with Monte Carlo search
            train_query, train_answer, train_labels = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                encoder,
                decoder,
                weights,
                bucket_id,
                mc_search=True)

            print(
                "=============================mc_search: True===================================="
            )
            if current_step % 200 == 0:
                for i in xrange(len(train_query)):
                    print("lable: ", train_labels[i])
                    print(" ".join([
                        tf.compat.as_str(rev_vocab[output])
                        for output in train_answer[i]
                    ]))

            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)

            # 3.Compute Reward r for (X, ^Y ) using D.---based on Monte Carlo search
            reward, _ = disc_step(sess,
                                  bucket_id,
                                  disc_model,
                                  train_query,
                                  train_answer,
                                  train_labels,
                                  forward_only=True)
            batch_reward += reward / gen_config.steps_per_checkpoint
            print("step_reward: ", reward)

            # 4.Update G on (X, ^Y ) using reward r   # update G with a policy-gradient step
            gan_adjusted_loss, gen_step_loss, _ = gen_model.step(
                sess,
                encoder,
                decoder,
                weights,
                bucket_id,
                forward_only=False,
                reward=reward,
                up_reward=True,
                debug=True)
            gen_loss += gen_step_loss / gen_config.steps_per_checkpoint

            print("gen_step_loss: ", gen_step_loss)
            print("gen_step_adjusted_loss: ", gan_adjusted_loss)

            # 5.Teacher-Forcing: Update G on (X, Y )   # update G with maximum likelihood
            t_adjusted_loss, t_step_loss, a = gen_model.step(
                sess, encoder, decoder, weights, bucket_id, forward_only=False)
            t_loss += t_step_loss / gen_config.steps_per_checkpoint

            print("t_step_loss: ", t_step_loss)
            print("t_adjusted_loss", t_adjusted_loss)  # print("normal: ", a)

            if current_step % gen_config.steps_per_checkpoint == 0:

                step_time += (time.time() -
                              start_time) / gen_config.steps_per_checkpoint

                print(
                    "current_steps: %d, step time: %.4f, disc_loss: %.3f, gen_loss: %.3f, t_loss: %.3f, reward: %.3f"
                    % (current_step, step_time, disc_loss, gen_loss, t_loss,
                       batch_reward))

                disc_loss_value = disc_loss_summary.value.add()
                disc_loss_value.tag = disc_config.name_loss
                disc_loss_value.simple_value = float(disc_loss)
                disc_writer.add_summary(disc_loss_summary,
                                        int(sess.run(disc_model.global_step)))

                gen_global_steps = sess.run(gen_model.global_step)
                gen_loss_value = gen_loss_summary.value.add()
                gen_loss_value.tag = gen_config.name_loss
                gen_loss_value.simple_value = float(gen_loss)
                t_loss_value = gen_loss_summary.value.add()
                t_loss_value.tag = gen_config.teacher_loss
                t_loss_value.simple_value = float(t_loss)
                batch_reward_value = gen_loss_summary.value.add()
                batch_reward_value.tag = gen_config.reward_name
                batch_reward_value.simple_value = float(batch_reward)
                gen_writer.add_summary(gen_loss_summary, int(gen_global_steps))

                if current_step % (gen_config.steps_per_checkpoint * 2) == 0:
                    print("current_steps: %d, save disc model" % current_step)
                    disc_ckpt_dir = os.path.abspath(
                        os.path.join(disc_config.train_dir, "checkpoints"))
                    if not os.path.exists(disc_ckpt_dir):
                        os.makedirs(disc_ckpt_dir)
                    disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
                    disc_model.saver.save(sess,
                                          disc_model_path,
                                          global_step=disc_model.global_step)

                    print("current_steps: %d, save gen model" % current_step)
                    gen_ckpt_dir = os.path.abspath(
                        os.path.join(gen_config.train_dir, "checkpoints"))
                    if not os.path.exists(gen_ckpt_dir):
                        os.makedirs(gen_ckpt_dir)
                    gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
                    gen_model.saver.save(sess,
                                         gen_model_path,
                                         global_step=gen_model.global_step)

                step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
                sys.stdout.flush()
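A note on the bucket sampling at the top of the training loop above: train_buckets_scale holds the cumulative share of each bucket, and a uniform random draw is mapped to the first bucket whose cumulative share exceeds it, so buckets are picked in proportion to how many training pairs they contain. Below is a minimal, self-contained sketch of that idea; the bucket sizes are hypothetical and none of the repository modules (gens, gen_config) are needed.

import numpy as np

def sample_bucket(bucket_sizes):
    total = float(sum(bucket_sizes))
    # Cumulative share of each bucket, e.g. [0.25, 0.75, 1.0].
    scale = [sum(bucket_sizes[:i + 1]) / total for i in range(len(bucket_sizes))]
    r = np.random.random_sample()
    # First bucket whose cumulative share exceeds the random draw.
    return min(i for i in range(len(scale)) if scale[i] > r)

counts = [0, 0, 0]
for _ in range(10000):
    counts[sample_bucket([5000, 10000, 5000])] += 1
print(counts)  # counts come out roughly proportional to the bucket sizes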
Code Example #2
def al_train():
    vocab, rev_vocab, dev_set, train_set = gens.prepare_data(gen_config)

    seq2seq = torch.load('pre_seq2seq.pth')
    optim_seq2seq = optim.Adam(seq2seq.parameters(), lr=gen_config.lr)
    hrnn = torch.load('pre_hrnn.pth')
    optim_hrnn = optim.Adam(hrnn.parameters(), lr=disc_config.lr)
    # hrnn, optim_hrnn = h_disc.create_model(disc_config)
    # seq2seq, optim_seq2seq = gens.create_model(gen_config)

    current_step = 0
    while True:
        current_step += 1
        start_time = time.time()
        print(
            "==================Update Discriminator: %d=====================" %
            current_step)
        for i in range(disc_config.disc_steps):
            # 1.Sample (X,Y) from real disc_data
            encoder_inputs, decoder_inputs, source_inputs, target_inputs = gens.getbatch(
                train_set, gen_config.batch_size, gen_config.maxlen)

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X)
            train_query, train_answer, train_labels, train_answer_gen = disc_train_data(
                seq2seq,
                vocab,
                source_inputs,
                target_inputs,
                encoder_inputs,
                decoder_inputs,
                mc_search=False)
            print(
                "==============================mc_search: False==================================="
            )
            if current_step % 200 == 0:
                print("train_query: ", len(train_query))
                print("train_answer: ", len(train_answer))
                print("train_labels: ", len(train_labels))
                for i in range(len(train_query)):
                    print("label: ", train_labels[i])
                    print("train_answer_sentence: ", train_answer[i])
                    print(" ".join(
                        [rev_vocab[output] for output in train_answer[i]]))

            # 3.Update D using (X, Y ) as positive examples and(X, ^Y) as negative examples
            step_loss = h_disc.disc_step(hrnn, optim_hrnn, disc_config,
                                         train_query, train_answer,
                                         train_labels)
            print("update discriminator loss is:", step_loss)

        for i in range(gen_config.gen_steps):
            print(
                "==================Update Generator: %d========================="
                % current_step)
            # 1.Sample (X,Y) from real disc_data
            encoder_inputs, decoder_inputs, source_inputs, target_inputs = gens.getbatch(
                train_set, gen_config.batch_size, gen_config.maxlen)

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X)
            train_query, train_answer, train_labels, train_answer_gen = disc_train_data(
                seq2seq,
                vocab,
                source_inputs,
                target_inputs,
                encoder_inputs,
                decoder_inputs,
                mc_search=False)
            train_query_neg = []
            train_answer_neg = []
            train_labels_neg = []
            train_query_pos = []
            train_answer_pos = []
            train_labels_pos = []
            for j in range(len(train_labels)):
                if train_labels[j] == 0:
                    train_query_neg.append(train_query[j])
                    train_answer_neg.append(train_answer[j])
                    train_labels_neg.append(0)
                else:
                    train_query_pos.append(train_query[j])
                    train_answer_pos.append(train_answer[j])
                    train_labels_pos.append(1)

            # 3.Compute Reward r for (X, ^Y ) using D.---based on Monte Carlo search
            reward = h_disc.disc_reward_step(hrnn, train_query_neg,
                                             train_answer_neg)
            # 4.update G on (X, ^Y) using reward r
            loss_reward = gens.train_with_reward(gen_config, seq2seq,
                                                 optim_seq2seq, reward,
                                                 train_query_neg,
                                                 train_answer_gen)
            # 5.Teacher-Forcing: update G on (X, Y)
            loss = gens.teacher_forcing(gen_config, seq2seq, optim_seq2seq,
                                        encoder_inputs, decoder_inputs)
            print(
                "update generator loss: reward is %f, loss_reward is %f, loss is %f"
                % (np.mean(reward), loss_reward, loss))
        end_time = time.time()
        print("step %d spend time: %f" % (current_step, end_time - start_time))

        if current_step % 1000 == 0:
            torch.save(seq2seq, './seq2seq.pth')
            torch.save(hrnn, './hrnn.pth')
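Step 4 of this PyTorch variant delegates the reward-weighted update to gens.train_with_reward, whose body is not shown here. As an illustration only (assumed tensor shapes, not the repository's implementation), a REINFORCE-style surrogate loss scales the log-likelihood of the sampled reply by the discriminator's reward before back-propagating:

import torch
import torch.nn.functional as F

def reward_weighted_step(model, optimizer, inputs, sampled_ids, reward):
    # model maps inputs -> (seq_len, vocab_size) logits for the sampled reply.
    logits = model(inputs)
    log_probs = F.log_softmax(logits, dim=-1)
    # Log-probability of each sampled token under the current policy.
    picked = log_probs.gather(1, sampled_ids.unsqueeze(1)).squeeze(1)
    loss = -(picked.sum() * reward)  # policy-gradient surrogate loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Toy usage with a linear layer standing in for the seq2seq decoder.
vocab_size, seq_len, hidden = 50, 7, 16
model = torch.nn.Linear(hidden, vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
inputs = torch.randn(seq_len, hidden)
sampled_ids = torch.randint(0, vocab_size, (seq_len,))
print(reward_weighted_step(model, optimizer, inputs, sampled_ids, reward=0.8))

A larger reward pushes the probabilities of the sampled tokens up more strongly; a reward near zero leaves the generator almost unchanged.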
Code Example #3
def al_train():
    tf_config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 1})
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    run_time = (run_options, run_metadata)
    np.random.seed(2)
    random.seed(2)
    tf.set_random_seed(2)
    # tf_config.gpu_options.per_process_gpu_memory_fraction = 0.7
    # sess_g = tf.Session(config=tf_config)
    # sess_r = tf.Session(config=tf_config)
    with tf.Session(config=tf_config) as sess_public:
        # sess_pair = (sess_g, sess_r)
        vocab, rev_vocab, test_set, dev_set, train_set = gens.prepare_data(gen_config)
        gen_config.vocab_size = len(rev_vocab)
        print("vocab sizei: {}".format(gen_config.vocab_size))
        for bucket in train_set:
            print("training set len: ", len(bucket))
        for bucket in test_set:
            print("testing set len: ", len(bucket))

        train_bucket_sizes = [len(train_set[b]) for b in xrange(len(gen_config.buckets))]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
                               for i in xrange(len(train_bucket_sizes))]
        g1 = tf.Graph()
        with g1.as_default():
            sess_r = tf.Session(config=tf_config, graph=g1)
            disc_model = r_disc.create_model(sess_r, disc_config, disc_config.name_model, vocab)
        g2 = tf.Graph()
        with g2.as_default():
            sess_g = tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=g2)
            gen_model = gens.create_model(sess_g, gen_config, forward_only=False, name_scope=gen_config.name_model,
                                      word2id=vocab)
        sess_pair = (sess_g, sess_r)
        # eval_model = eval_disc.create_model(sess, evl_config, evl_config.name_model, vocab)
        current_step = 0
        step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
        disc_step = 10
        if gen_config.continue_train:
            disc_step = 5
        reward_base = 0
        reward_history = np.zeros(100)
        while True:
            start_time = time.time()
            random_number_01 = np.random.random_sample()
            bucket_id = min([i for i in xrange(len(train_buckets_scale))
                             if train_buckets_scale[i] > random_number_01])
            print("Sampled bucket ID: {}".format(bucket_id))
            # disc_config.max_len = gen_config.buckets[bucket_id][0] + gen_config.buckets[bucket_id][1]
            # b_query, b_gen = train_set[bucket_id], dev_set[bucket_id]
            '''
            if current_step % 10 == 0 and current_step != 0 or (current_step ==0 and gen_config.testing):
                print("==========Evaluate dev set: %d==========" % current_step)
                bleu_score = evaluate_gan(sess=sess_pair,
                                          gen_model=gen_model,
                                          eval_model=None,
                                          gen_config=gen_config,
                                          disc_model=disc_model,
                                          dataset=test_set,
                                          buckets=gen_config.buckets,
                                          rev_vocab=rev_vocab)

                print("Bleu-1 score on dev set: %.4f" % bleu_score[0])
                print("Bleu-2 score on dev set: %.4f" % bleu_score[1])
                print("Bleu-3 score on dev set: %.4f" % bleu_score[2])
            '''
            if gen_config.testing:
                break

            print("==========Update Discriminator: %d==========" % current_step)
            disc_step_loss = train_disc(sess=sess_pair,
                                        gen_model=gen_model,
                                        disc_model=disc_model,
                                        train_set=train_set,
                                        bucket_id=bucket_id,
                                        rev_vocab=rev_vocab,
                                        current_step=current_step,
                                        disc_freq=disc_step)
            disc_step = 5
            disc_loss += disc_step_loss / disc_config.steps_per_checkpoint

            disc_time = time.time()
            print("disc training time %.2f" % (disc_time - start_time))

            print("==========Update Generator: %d==========" % current_step)

            update_gen_data = gen_model.get_batch(train_set, bucket_id, gen_config.batch_size)
            encoder_real, decoder_real, weights_real, source_inputs_real, source_outputs_real, target_real = update_gen_data

            # 2.Sample (X, ^Y) through ^Y ~ G(*|X) with MC
            # answers have no EOS_ID
            sampled_query, sampled_answer, _ = sample_relpy_with_x(sess=sess_g,
                                                                   gen_model=gen_model,
                                                                   source_inputs=source_inputs_real,
                                                                   source_outputs=source_outputs_real,
                                                                   encoder_inputs=encoder_real,
                                                                   decoder_inputs=decoder_real,
                                                                   target_weights=weights_real,
                                                                   target_input=target_real,
                                                                   bucket_id=bucket_id,
                                                                   mc_position=0)
            sample_time = time.time()
            print("sampling time %.2f" % (sample_time - disc_time))
            gen_sampled_batch = gen_model.gen_batch_preprocess(query=sampled_query,
                                                               answer=sampled_answer,
                                                               bucket_id=bucket_id,
                                                               batch_size=gen_config.batch_size)
            # source answers have no EOS_ID
            encoder_sampled, decoder_sampled, weights_sampled, source_inputs_sampled, source_outputs_sampled, target_sampled = gen_sampled_batch

            # 3. MC search to approximate the reward at each position for the sampled reply
            mc_samples, mc_reward, mc_adjusted_word = mc_sampler_fast(sess=sess_pair,
                                                                      gen_model=gen_model,
                                                                      source_inputs=source_inputs_sampled,
                                                                      source_outputs=source_outputs_sampled,
                                                                      encoder_inputs=encoder_sampled,
                                                                      decoder_inputs=decoder_sampled,
                                                                      target_weights=weights_sampled,
                                                                      target_inputs=target_sampled,
                                                                      bucket_id=bucket_id,
                                                                      disc_model=disc_model,
                                                                      reward_base=reward_base,
                                                                      run_hist=run_time)
            reward_history[current_step % 100] = np.sum(mc_reward) / np.count_nonzero(mc_reward)
            if current_step < 100:
                reward_base = np.sum(reward_history) / (current_step + 1)
            else:
                reward_base = np.sum(reward_history) / 100

            mc_time = time.time()
            print("mc time %.2f" % (mc_time - sample_time))

            batch_reward_step = np.mean(mc_reward[0])
            batch_reward_step_first_line = mc_reward[:, 0]
            # print("step_reward: ", np.mean(mc_reward[-1]))

            # 4.Update G on (X, ^Y ) using mc_reward
            gan_adjusted_loss, gen_step_loss, _, _ = gen_model.step(sess_g,
                                                                    encoder_sampled,
                                                                    decoder_sampled,
                                                                    target_sampled,
                                                                    weights_sampled,
                                                                    bucket_id,
                                                                    forward_only=False,
                                                                    reward=mc_adjusted_word,
                                                                    up_reward=True,
                                                                    debug=True
                                                                    )
            print("step_reward: ", batch_reward_step_first_line)
            print("gen_step_loss: ", gen_step_loss)
            print("gen_step_adjusted_loss: ", gan_adjusted_loss)
            batch_reward += batch_reward_step / gen_config.steps_per_checkpoint
            gen_loss += gen_step_loss / gen_config.steps_per_checkpoint

            gen_time = time.time()
            print("gen update time %.2f" % (gen_time - mc_time))
            print("Gen training time %.2f" % (gen_time - disc_time))

            if gen_config.teacher_forcing:
                print("==========Teacher-Forcing: %d==========" % current_step)
                # encoder_real, decoder_real, weights_real = true_dialog
                reward_god = []
                reward_arr = np.array(weights_real) - 0.0
                for idx in range(len(weights_real)):
                    reward_god.append(np.sum(reward_arr[idx:], axis=0))
                reward_god = np.array(reward_god).tolist()
                t_adjusted_loss, t_step_loss, _, a = gen_model.step(sess_g,
                                                                    encoder_real,
                                                                    decoder_real,
                                                                    target_real,
                                                                    weights_real,
                                                                    bucket_id,
                                                                    reward=reward_god,
                                                                    teacher_forcing=True,
                                                                    forward_only=False)
                t_loss += t_step_loss / gen_config.steps_per_checkpoint
                print("t_step_loss: ", t_step_loss)
                print("t_adjusted_loss", t_adjusted_loss)  # print("normal: ", a)
                teacher_time = time.time()
                print("teacher time %.2f" % (teacher_time - gen_time))

            if current_step % gen_config.steps_per_checkpoint == 0:

                step_time += (time.time() - start_time) / gen_config.steps_per_checkpoint

                print("current_steps: %d, step time: %.4f, disc_loss: %.3f, gen_loss: %.3f, t_loss: %.3f, reward: %.3f"
                      % (current_step, step_time, disc_loss, gen_loss, t_loss, batch_reward))

                if current_step % (gen_config.steps_per_checkpoint * 1) == 0:
                    print("current_steps: %d, save disc model" % current_step)
                    disc_ckpt_dir = os.path.abspath(
                        os.path.join(disc_config.model_dir, 'disc_model',
                                     "data-{}_pre_embed-{}_ent-{}_exp-{}_teacher-{}".format(
                                         disc_config.data_id,
                                         disc_config.pre_embed,
                                         disc_config.ent_weight,
                                         disc_config.exp_id,
                                         disc_config.teacher_forcing)))
                    if not os.path.exists(disc_ckpt_dir):
                        os.makedirs(disc_ckpt_dir)
                    disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
                    disc_model.saver.save(sess_r, disc_model_path, global_step=disc_model.global_step)

                    print("current_steps: %d, save gen model" % current_step)
                    gen_ckpt_dir = os.path.abspath(
                        os.path.join(gen_config.model_dir, 'gen_model',
                                     "data-{}_pre_embed-{}_ent-{}_exp-{}_teacher-{}".format(
                                         gen_config.data_id,
                                         gen_config.pre_embed,
                                         gen_config.ent_weight,
                                         gen_config.exp_id,
                                         gen_config.teacher_forcing)))
                    if not os.path.exists(gen_ckpt_dir):
                        os.makedirs(gen_ckpt_dir)
                    gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
                    gen_model.saver.save(sess_g, gen_model_path, global_step=gen_model.global_step)

                    step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
                    sys.stdout.flush()

            current_step += 1
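The reward_history / reward_base bookkeeping above keeps a moving average of roughly the last 100 per-step mean rewards and passes it to the Monte Carlo sampler as a baseline, so only better-than-average replies receive a positive learning signal. Here is a small numpy-only sketch of that idea; the exact windowing in the code above differs slightly, and this version is purely illustrative.

import numpy as np

class RewardBaseline(object):
    """Moving-average baseline over the last `window` per-step mean rewards."""

    def __init__(self, window=100):
        self.history = np.zeros(window)
        self.window = window
        self.step = 0

    def update(self, mc_reward):
        mc_reward = np.asarray(mc_reward, dtype=float)
        # Mean over the non-zero (non-padded) positions, as in the code above.
        self.history[self.step % self.window] = (
            mc_reward.sum() / np.count_nonzero(mc_reward))
        self.step += 1
        return self.history.sum() / min(self.step, self.window)

baseline = RewardBaseline()
print(baseline.update([0.9, 0.0, 0.7]))  # hypothetical per-position rewards
print(baseline.update([0.4, 0.5, 0.0]))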
Code Example #4
def al_train():
    with tf.Session() as sess:
        current_step = 1
        disc_model = h_disc.create_model(sess, disc_config)
        gen_model = gens.create_model(sess, gen_config)
        vocab, rev_vocab, dev_set, train_set = gens.prepare_data(gen_config)
        for bucket in train_set:
            print("train len: ", len(bucket))

        train_bucket_sizes = [
            len(train_set[b]) for b in xrange(len(gen_config.buckets))
        ]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in xrange(len(train_bucket_sizes))
        ]

        while True:
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in xrange(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])
            disc_config.max_len = gen_config.buckets[bucket_id][
                0] + gen_config.buckets[bucket_id][1]
            print(
                "===========================Update Discriminator================================"
            )
            # 1.Sample (X,Y) from real disc_data
            print("bucket_id: %d" % bucket_id)

            encoder_inputs, decoder_inputs, target_weights, source_inputs, source_outputs = gen_model.get_batch(
                train_set, bucket_id, gen_config.batch_size)
            print("source_inputs: ", len(source_inputs))
            print("source_outputs: ", len(source_outputs))
            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X)
            train_query, train_answer, train_labels = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                encoder_inputs,
                decoder_inputs,
                target_weights,
                bucket_id,
                mc_search=False)
            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)
            # 3.Update D using (X, Y ) as positive examples and(X, ^Y) as negative examples
            disc_step(sess,
                      bucket_id,
                      disc_model,
                      train_query,
                      train_answer,
                      train_labels,
                      forward_only=False)

            print(
                "===============================Update Generator================================"
            )
            # 1.Sample (X,Y) from real disc_data
            update_gen_data = gen_model.get_batch(train_set, bucket_id,
                                                  gen_config.batch_size)
            encoder, decoder, weights, source_inputs, source_outputs = update_gen_data

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X) with Monte Carlo search
            train_query, train_answer, train_labels = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                encoder,
                decoder,
                weights,
                bucket_id,
                mc_search=True)
            train_query = np.transpose(train_query)
            train_answer = np.transpose(train_answer)
            # 3.Compute Reward r for (X, ^Y ) using D.---based on Monte Carlo search
            reward = disc_step(sess,
                               bucket_id,
                               disc_model,
                               train_query,
                               train_answer,
                               train_labels,
                               forward_only=True)

            # 4.Update G on (X, ^Y ) using reward r
            _, loss, a = gen_model.step(sess,
                                        encoder,
                                        decoder,
                                        weights,
                                        bucket_id,
                                        forward_only=False,
                                        reward=reward,
                                        debug=True)
            print("up_reward: ", a)

            # 5.Teacher-Forcing: Update G on (X, Y )
            _, loss, a = gen_model.step(sess,
                                        encoder,
                                        decoder,
                                        weights,
                                        bucket_id,
                                        forward_only=False)
            print("loss: ", loss)
            print("normal: ", a)

            if current_step % steps_per_checkpoint == 0:
                print("save disc model")
                disc_ckpt_dir = os.path.abspath(
                    os.path.join(disc_config.data_dir, "checkpoints"))
                if not os.path.exists(disc_ckpt_dir):
                    os.makedirs(disc_ckpt_dir)
                disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
                disc_model.saver.save(sess,
                                      disc_model_path,
                                      global_step=disc_model.global_step)

                print("save gen model")
                gen_ckpt_dir = os.path.abspath(
                    os.path.join(gen_config.data_dir, "checkpoints"))
                if not os.path.exists(gen_ckpt_dir):
                    os.makedirs(gen_ckpt_dir)
                gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
                gen_model.saver.save(sess,
                                     gen_model_path,
                                     global_step=gen_model.global_step)
            current_step += 1
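In every variant shown here, disc_train_data is expected to pair each real reply with a generated one for the same query, labelling the human reply 1 and the sampled reply 0. Its implementation is not part of these snippets; a minimal, hypothetical sketch of that pairing over plain token-id lists could look like the following.

def build_disc_batch(queries, real_answers, generated_answers):
    train_query, train_answer, train_labels = [], [], []
    for query, real, fake in zip(queries, real_answers, generated_answers):
        train_query.append(query)
        train_answer.append(real)
        train_labels.append(1)  # positive example: human reply
        train_query.append(query)
        train_answer.append(fake)
        train_labels.append(0)  # negative example: generated reply
    return train_query, train_answer, train_labels

print(build_disc_batch([[3, 7, 9]], [[12, 5]], [[12, 8]]))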
Code Example #5
def al_train():
    gen_config.batch_size = 1
    with tf.Session() as sess:
        disc_model = discs.create_model(sess, disc_config, is_training=True)
        gen_model = gens.create_model(sess, gen_config, forward_only=True)
        vocab, rev_vocab, dev_set, train_set = gens.prepare_data(gen_config)
        train_bucket_sizes = [
            len(train_set[b]) for b in xrange(len(gen_config.buckets))
        ]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in xrange(len(train_bucket_sizes))
        ]

        while True:
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in xrange(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])

            print(
                "===========================Update Discriminator================================"
            )
            # 1.Sample (X,Y) from real data
            _, _, _, source_inputs, source_outputs = gen_model.get_batch(
                train_set, bucket_id, 0)
            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X)
            train_inputs, train_labels, train_masks, _ = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                mc_search=False)
            # 3.Update D using (X, Y ) as positive examples and(X, ^Y) as negative examples
            disc_step(sess, disc_model, train_inputs, train_labels,
                      train_masks)

            print(
                "===============================Update Generator================================"
            )
            # 1.Sample (X,Y) from real data
            update_gen_data = gen_model.get_batch(train_set, bucket_id, 0)
            encoder, decoder, weights, source_inputs, source_outputs = update_gen_data

            # 2.Sample (X,Y) and (X, ^Y) through ^Y ~ G(*|X) with Monte Carlo search
            train_inputs, train_labels, train_masks, responses = disc_train_data(
                sess,
                gen_model,
                vocab,
                source_inputs,
                source_outputs,
                mc_search=True)
            # 3.Compute Reward r for (X, ^Y ) using D.---based on Monte Carlo search
            reward = disc_step(sess, disc_model, train_inputs, train_labels,
                               train_masks)

            # 4.Update G on (X, ^Y ) using reward r
            dec_gen = responses[0][:gen_config.buckets[bucket_id][1]]
            if len(dec_gen) < gen_config.buckets[bucket_id][1]:
                dec_gen = dec_gen + [0] * (gen_config.buckets[bucket_id][1] -
                                           len(dec_gen))
            dec_gen = np.reshape(dec_gen, (-1, 1))
            gen_model.step(sess,
                           encoder,
                           dec_gen,
                           weights,
                           bucket_id,
                           forward_only=False,
                           up_reward=True,
                           reward=reward,
                           debug=True)

            # 5.Teacher-Forcing: Update G on (X, Y )
            _, loss, _ = gen_model.step(sess,
                                        encoder,
                                        decoder,
                                        weights,
                                        bucket_id,
                                        forward_only=False,
                                        up_reward=False)
            print("loss: ", loss)

        #add checkpoint
        checkpoint_dir = os.path.abspath(
            os.path.join(disc_config.out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "disc.model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        pass
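Step 4 of this last variant feeds the first generated response back through gen_model.step after truncating or zero-padding it to the bucket's decoder length and reshaping it into a (decoder_size, 1) column. A standalone sketch of just that preparation (pad id 0 assumed, matching the code above):

import numpy as np

def pad_response(response, decoder_size, pad_id=0):
    dec_gen = response[:decoder_size]
    if len(dec_gen) < decoder_size:
        dec_gen = dec_gen + [pad_id] * (decoder_size - len(dec_gen))
    # Column vector of shape (decoder_size, 1): one decoding step per row.
    return np.reshape(dec_gen, (-1, 1))

print(pad_response([4, 8, 15, 16], decoder_size=6))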