Code example #1
File: trainer.py  Project: aichunks/NLP
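These listings omit their module-level imports. A minimal preamble consistent with the names they use would be the following; the Dataloader import path is hypothetical, and FLAGS assumes the project defines its options via tf.app.flags:

import random
import time

import numpy as np
import tensorflow as tf
from rouge import Rouge
from sklearn.metrics import accuracy_score
from tqdm import tqdm

from data_loader import Dataloader  # hypothetical path; see the sketch after code example #1

FLAGS = tf.app.flags.FLAGS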
def pretrain_discriminator(discriminator, sess):
    print("Pretrain discriminator...")
    dis_train_loader = Dataloader(FLAGS.dis_batch_size, FLAGS.vocab_size)
    dis_val_loader = Dataloader(FLAGS.dis_batch_size, FLAGS.vocab_size)
    pretrain_dis_data = np.load(FLAGS.pretrain_dis_data_path)
    pos_summary = pretrain_dis_data['pos_summary_idx']
    neg_summary = pretrain_dis_data['neg_summary_idx']
    assert len(pos_summary) == len(neg_summary)

    train_max_epoch = 20  # max training epochs
    val_num = 1000  # number of validation samples
    pos_train = []
    neg_train = []
    pos_val = []
    neg_val = []
    # Use a set for O(1) membership checks in the loop below.
    val_select = set(random.sample(range(len(pos_summary)), val_num))
    for i in range(len(pos_summary)):
        if i in val_select:
            pos_val.append(pos_summary[i][:FLAGS.max_dec_steps])
            neg_val.append(neg_summary[i][:FLAGS.max_dec_steps])
        else:
            pos_train.append(pos_summary[i][:FLAGS.max_dec_steps])
            neg_train.append(neg_summary[i][:FLAGS.max_dec_steps])
    print("length train:", len(pos_train))
    print("length val:", len(pos_val))

    for epoch in tqdm(range(train_max_epoch)):
        # training process
        dis_train_loader.load_data(pos_train, neg_train)
        dis_train_loader.reset_pointer()
        for it in range(dis_train_loader.num_batch):
            x_batch, y_batch = dis_train_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.input_y: y_batch,
                discriminator.dropout_keep_prob: 0.5
            }
            sess.run(discriminator.train_op, feed)

        # validation process
        dis_val_loader.load_data(pos_val, neg_val)
        dis_val_loader.reset_pointer()
        acc_list = []
        for it in range(dis_val_loader.num_batch):
            x_batch, y_batch = dis_val_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.input_y: y_batch,
                discriminator.dropout_keep_prob: 1.0
            }
            pred = sess.run(discriminator.predictions, feed)
            # Recover integer class ids from the one-hot label rows: the last
            # element of np.where's tuple is the column indices.
            target = np.where(np.array(y_batch) == 1)[-1]
            acc_list.append(accuracy_score(y_pred=pred, y_true=target))
        eval_acc = np.mean(acc_list)
        print('pretrain epoch:{}, eval accuracy: {}'.format(epoch, eval_acc))
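Dataloader itself does not appear on this page. From the calls above (load_data, reset_pointer, num_batch, next_batch), a minimal sketch of the interface it must provide might look like the following; the [0, 1]/[1, 0] one-hot labeling, the shuffling, and the assumption that all sequences share one padded length are guesses consistent with how y_batch is consumed in the evaluation loop:

import numpy as np

class Dataloader:
    """Minimal sketch (assumed interface, not the project's actual code)."""

    def __init__(self, batch_size, vocab_size):
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.num_batch = 0
        self.pointer = 0

    def load_data(self, positive, negative):
        # Real summaries get label [0, 1], generated ones [1, 0], so that
        # column 1 of the discriminator's output is P(real).
        examples = np.array(list(positive) + list(negative))
        labels = np.array([[0, 1]] * len(positive) +
                          [[1, 0]] * len(negative))
        perm = np.random.permutation(len(examples))
        examples, labels = examples[perm], labels[perm]
        self.num_batch = len(examples) // self.batch_size
        n = self.num_batch * self.batch_size
        self.x_batches = np.split(examples[:n], self.num_batch)
        self.y_batches = np.split(labels[:n], self.num_batch)

    def reset_pointer(self):
        self.pointer = 0

    def next_batch(self):
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer = (self.pointer + 1) % self.num_batch
        return x, y

Under this labeling, item[1] in code example #2's ypred_for_auc loop reads off the probability of the real class, which is what the rollout rewards use.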
Code example #2
File: trainer.py  Project: aichunks/NLP
def adversarial_train(generator, discriminator, generator_batcher,
                      discriminator_batcher, summary_writer,
                      sess_context_manager):
    print("Start adversarial training...")
    with sess_context_manager as sess:
        # Rewards start at zero; each loop iteration recomputes them and
        # feeds them to the generator on the next pass.
        D_rewards = np.zeros((FLAGS.batch_size, FLAGS.max_dec_steps))
        rouge_rewards = np.zeros((FLAGS.batch_size, 1))

        while True:
            # Train generator for one step
            print("Start to train generator...")
            batch = generator_batcher.next_batch()
            batch.batch_reward = D_rewards
            batch.batch_rouge_reward = rouge_rewards
            t0 = time.time()
            result_train = generator.run_train_step(sess, batch)
            loss = result_train['loss']
            summaries = result_train['summaries']
            train_step = result_train['global_step']
            summary_writer.add_summary(summaries, train_step)
            print("train step: %d train loss: %.3f time: %.3fs" %
                  (train_step, loss, time.time() - t0))

            rouge_rewards = []
            target_token = batch.target_batch
            output_sample_token = np.transpose(
                np.squeeze(result_train['output_sample_token']))
            output_argmax_token = np.transpose(
                np.squeeze(result_train['output_summary_token']))
            rouge = Rouge()
            for target, sample, argmax in zip(target_token,
                                              output_sample_token,
                                              output_argmax_token):
                # remove_eos is defined in code example #3: it truncates the
                # sequence at the first EOS id (3) and joins the ids into a
                # space-separated string for the rouge package.
                target_ = remove_eos(target)
                sample_ = remove_eos(sample)
                argmax_ = remove_eos(argmax)
                if len(argmax_) > 0:
                    r_baseline = rouge.get_scores(argmax_,
                                                  target_)[0]["rouge-l"]["f"]
                else:
                    r_baseline = 0
                if len(sample_) > 0:
                    r_sample = rouge.get_scores(sample_,
                                                target_)[0]["rouge-l"]["f"]
                else:
                    r_sample = 0
                #print("r_baseline:", r_baseline)
                #print("r_sample:", r_sample)
                rouge_rewards.append(r_baseline - r_sample)
            rouge_rewards = np.reshape(rouge_rewards, [FLAGS.batch_size, 1])
            print("RL reward for rouge-L: %.3f" % np.mean(rouge_rewards))

            print("running rollout step...")
            t0 = time.time()
            result_rollout = generator.run_rollout_step(sess, batch)
            rollout_output = result_rollout['rollout_token']
            # shape: [rollout_num, number_of_rolls, batch_size, seq_len]
            print("rollout step: %.3fs" % (time.time() - t0))

            # calculate D_reward
            print("start to calculate D_rewards")
            feed_output_token = []
            rollout_output = np.reshape(rollout_output,
                                        [-1, FLAGS.max_dec_steps])
            for sent in rollout_output:
                # 3 is the EOS token id; everything after the first EOS is
                # padded with 1s (assumed to be this project's PAD id).
                index_list = np.where(sent == 3)[0]
                if len(index_list) != 0:
                    ind = index_list[0]
                    new_sent = np.concatenate([
                        sent[:ind + 1],
                        np.ones(FLAGS.max_dec_steps - ind - 1)
                    ])
                    feed_output_token.append(new_sent)
                else:
                    feed_output_token.append(sent)
            feed_output_token = np.array(feed_output_token)
            # Clip out-of-vocabulary ids to 0 so every token indexes into
            # the discriminator's vocabulary.
            feed_output_token[feed_output_token > FLAGS.vocab_size - 1] = 0

            # Score the rollouts with the discriminator, one chunk per rollout
            ypred_for_auc = []
            for token in np.split(feed_output_token, FLAGS.rollout):
                feed = {
                    discriminator.input_x: token,
                    discriminator.dropout_keep_prob: 1.0
                }
                # Per chunk: [number_of_rolls * batch_size, 2]
                ypred_auc = sess.run(discriminator.ypred_for_auc, feed)
                ypred_for_auc.append(ypred_auc)
            ypred_for_auc = np.concatenate(ypred_for_auc)
            # Keep the probability of the "real" class, then average over
            # rollouts to get one reward per rollout position.
            ypred = np.array([item[1] for item in ypred_for_auc])
            ypred = np.reshape(ypred, [FLAGS.rollout, -1, FLAGS.batch_size])
            rewards = np.transpose(np.sum(ypred, 0)) / (
                1.0 * FLAGS.rollout)  # [batch_size, number of rolls]

            if np.std(rewards) != 0.:
                rewards = (rewards - np.mean(rewards)) / np.std(rewards)
            # Spread one averaged reward into every 10th decoder step; the
            # hardcoded stride assumes max_dec_steps / 10 rollout positions
            # (code example #3 derives the stride from rewards.shape[1]).
            D_rewards = np.zeros([FLAGS.batch_size, FLAGS.max_dec_steps])
            for count, i in enumerate(range(1, FLAGS.max_dec_steps, 10)):
                D_rewards[:, i] = rewards[:, count]
            print("D_rewards:", D_rewards.shape)

            # Train discriminator
            print("Start to train discriminator...")
            for _ in tqdm(range(5)):
                batch = discriminator_batcher.next_batch()
                result = generator.run_summary_token_step(sess, batch)
                output_summary_token = result['output_summary_token']
                output_summary_token = np.transpose(
                    np.squeeze(
                        output_summary_token))  # [batch_size, max_dec_steps]
                ground_truth = batch.target_batch  # [batch_size, max_dec_steps]
                output_summary = []
                for sent in output_summary_token:
                    index_list = np.where(sent == 3)[0]
                    if len(index_list) != 0:
                        ind = index_list[0]
                        new_sent = np.concatenate([
                            sent[:ind + 1],
                            np.ones(FLAGS.max_dec_steps - ind - 1)
                        ])
                        output_summary.append(new_sent)
                    else:
                        output_summary.append(sent)
                output_summary = np.array(output_summary)

                max_epoch = 3
                dis_loader = Dataloader(FLAGS.batch_size, FLAGS.vocab_size)
                pos_train = [ground_truth[i] for i in range(len(ground_truth))]
                neg_train = [
                    output_summary[i] for i in range(len(output_summary))
                ]
                for _ in range(max_epoch):
                    dis_loader.load_data(pos_train, neg_train)
                    dis_loader.reset_pointer()
                    # train for 1 epoch
                    for it in range(dis_loader.num_batch):
                        x_batch, y_batch = dis_loader.next_batch()
                        feed = {
                            discriminator.input_x: x_batch,
                            discriminator.input_y: y_batch,
                            discriminator.dropout_keep_prob: 0.5
                        }
                        sess.run(discriminator.train_op, feed)
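Code example #2 above calls remove_eos without defining it; the helper is the one defined inline in code example #3. It truncates a token-id sequence at the first EOS id (3) and joins the remaining ids into a space-separated string, the input format the rouge package expects. A standalone check of that behavior:

import numpy as np

def remove_eos(input_text):
    # Clip at the first EOS id (3); keep everything if no EOS is present.
    eos_positions = np.where(input_text == 3)[0]
    clipped = input_text[:eos_positions[0]] if len(eos_positions) else input_text
    return ' '.join(map(str, clipped))

print(remove_eos(np.array([17, 42, 8, 3, 1, 1])))  # -> 17 42 8
print(remove_eos(np.array([17, 42, 8])))           # -> 17 42 8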
Code example #3
def run_training(generator, discriminator, generator_batcher,
                 discriminator_batcher, summary_writer, sess_context_manager):
    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')

    with sess_context_manager as sess:
        D_rewards = np.zeros((FLAGS.batch_size, FLAGS.max_dec_steps))
        rouge_rewards = np.zeros((FLAGS.batch_size, 1))

        while True:
            # Train the generator for one step
            for it in range(1):
                batch = generator_batcher.next_batch()
                batch.batch_reward = D_rewards
                batch.batch_rouge_reward = rouge_rewards

                tf.logging.info('running training step...')
                t0 = time.time()
                result_train = generator.run_train_step(sess, batch)

                t1 = time.time()
                tf.logging.info('seconds for training step: %.3f', t1 - t0)
                loss = result_train['loss']
                tf.logging.info('Generator train loss: %f',
                                loss)  # print the loss to screen

                summaries = result_train['summaries']
                train_step = result_train['global_step']
                summary_writer.add_summary(summaries,
                                           train_step)  # write the summaries

                rg = Rouge()

                gtruth_token = batch.target_batch
                output_sample_token = np.transpose(
                    np.squeeze(result_train['output_sample_token']))
                output_argmax_token = np.transpose(
                    np.squeeze(result_train['output_summary_token']))

                def remove_eos(input_text):
                    # Truncate at the first EOS id (3) and join the token ids
                    # into a space-separated string for the rouge package.
                    _input_text_eos = np.where(input_text == 3)[0]
                    if len(_input_text_eos) != 0:
                        clipped_text = input_text[:_input_text_eos[0]]
                    else:
                        clipped_text = input_text
                    return ' '.join(map(str, clipped_text))

                rouge_rewards = []

                for gt, sample, argmax in zip(gtruth_token,
                                              output_sample_token,
                                              output_argmax_token):
                    _gt = remove_eos(gt)
                    _sample = remove_eos(sample)
                    _argmax = remove_eos(argmax)

                    # Guard empty hypotheses (rouge raises on empty input)
                    # and pass the hypothesis first, matching the
                    # get_scores(hyps, refs) signature and code example #2.
                    if len(_argmax) > 0:
                        r_baseline = rg.get_scores(
                            _argmax, _gt)[0]['rouge-l']['f']
                    else:
                        r_baseline = 0
                    if len(_sample) > 0:
                        r_sample = rg.get_scores(
                            _sample, _gt)[0]['rouge-l']['f']
                    else:
                        r_sample = 0
                    rouge_rewards.append(r_baseline - r_sample)

                rouge_rewards = np.reshape(rouge_rewards,
                                           [FLAGS.batch_size, 1])
                tf.logging.info('RL reward for rouge-L: %.3f',
                                np.mean(rouge_rewards))

                tf.logging.info('running rollout step...')
                t0 = time.time()
                result_rollout = generator.run_rollout_step(sess, batch)
                t1 = time.time()
                tf.logging.info('seconds for rollout step: %.3f', t1 - t0)

                # shape [rollout_num, seqlen(this is number of roll), batch_size, seq_len]
                rollout_output = result_rollout['rollout_token']
                given_number_of_rollout = rollout_output.shape[1]

                # calculate D_reward
                print("start to calculate D_rewards")
                _feed_output_token = np.reshape(rollout_output,
                                                [-1, FLAGS.max_dec_steps])

                feed_output_token = []
                for sent in _feed_output_token:
                    index_list = np.where(sent == 3)[0]
                    if len(index_list) != 0:
                        ind = index_list[0]
                        # Pad to max_dec_steps so the row length matches the
                        # reshape above; 1 is assumed to be the PAD id.
                        new_sent = np.concatenate(
                            [sent[:ind + 1],
                             np.ones(FLAGS.max_dec_steps - ind - 1)])
                        feed_output_token.append(new_sent)
                    else:
                        new_sent = np.array(sent, dtype=np.int32)
                        feed_output_token.append(new_sent)

                feed_output_token = np.array(feed_output_token)
                feed_output_token = feed_output_token.reshape(
                    (len(feed_output_token), -1))
                print("feed_out_token.shape:", feed_output_token.shape)
                if feed_output_token.shape[1] > 1:
                    # Clip out-of-vocabulary ids to 0, as in code example #2.
                    feed_output_token[
                        feed_output_token > FLAGS.vocab_size - 1] = 0

                    # Score the rollouts with the discriminator.
                    ypred_for_auc = []
                    for feed_output_token_small in np.split(
                            feed_output_token, FLAGS.rollout):
                        feed = {
                            discriminator.input_x: feed_output_token_small,
                            discriminator.dropout_keep_prob: 1.0
                        }
                        # ypred_for_auc: [rollout_num * seqlen(this is number of roll) * batch_size, 2]
                        ypred_for_auc.append(
                            sess.run(discriminator.ypred_for_auc, feed))
                    ypred_for_auc = np.concatenate(ypred_for_auc)
                    ypred = np.array([item[1] for item in ypred_for_auc])
                    framed_ypred = np.reshape(ypred, [
                        FLAGS.rollout, given_number_of_rollout,
                        FLAGS.batch_size
                    ])
                    # Average over rollouts: [batch_size, number of rolls]
                    rewards = np.transpose(np.sum(framed_ypred, 0)) / (
                        1.0 * FLAGS.rollout)
                    if np.std(rewards) != 0.:
                        rewards = (rewards -
                                   np.mean(rewards)) / np.std(rewards)
                    D_rewards = np.zeros(
                        (FLAGS.batch_size, FLAGS.max_dec_steps))
                    print("rewards.shape:", rewards.shape)

                    for count, i in enumerate(
                            range(1, FLAGS.max_dec_steps,
                                  int(FLAGS.max_dec_steps /
                                      rewards.shape[1]))):
                        D_rewards[:, i] = rewards[:, count]

                else:
                    tmp = []
                    for i in range(len(feed_output_token)):
                        tmp.append(feed_output_token[i][0])
                    feed_output_token = np.array(tmp).copy()
                    print("feed-new:", feed_output_token.shape)
                    print("Filter out!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

            # Train the discriminator
            print("Start to train the Discriminator!")
            for _ in tqdm(range(5)):
                batch = discriminator_batcher.next_batch()
                res = generator.run_summary_token_step(sess, batch)
                _output_argmax_summary = res['output_summary_token']
                _output_argmax_summary = np.transpose(
                    np.squeeze(
                        _output_argmax_summary))  # [batch_size, max_dec_steps]
                gtruth_data = batch.target_batch  # [batch_size, max_dec_steps]; format: [[], [], ...]

                output_argmax_summary = []
                for sent in _output_argmax_summary:
                    index_list = np.where(sent == 3)[0]
                    if len(index_list) != 0:
                        ind = index_list[0]
                        new_sent = np.concatenate([
                            sent[:ind + 1],
                            np.ones(FLAGS.max_dec_steps - ind - 1)
                        ])
                        output_argmax_summary.append(new_sent)
                    else:
                        output_argmax_summary.append(sent)
                output_argmax_summary = np.array(output_argmax_summary)

                positive_examples = []
                negative_examples = []
                for ele in gtruth_data:
                    positive_examples.append(ele)
                for ele in output_argmax_summary:
                    negative_examples.append(ele)
                dis_data_loader = Dataloader(FLAGS.batch_size,
                                             FLAGS.vocab_size)

                max_epoch = 3

                for epoch in range(max_epoch):
                    dis_data_loader.load_data(positive_examples,
                                              negative_examples)
                    dis_data_loader.reset_pointer()
                    for it in range(dis_data_loader.num_batch):
                        x_batch, y_batch = dis_data_loader.next_batch()
                        feed = {
                            discriminator.input_x: x_batch,
                            discriminator.input_y: y_batch,
                            discriminator.dropout_keep_prob: 0.5
                        }
                        _ = sess.run(discriminator.train_op, feed)
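Both adversarial loops scatter one averaged rollout reward into every k-th decoder step and leave zeros in between: code example #2 hardcodes k = 10, while this version derives k from rewards.shape[1]. A standalone sketch of that indexing with toy sizes (all names and numbers here are illustrative, not the project's):

import numpy as np

batch_size, max_dec_steps, rollout_num = 4, 100, 3
n_rolls = 10  # rollout positions per sequence
ypred = np.random.rand(rollout_num, n_rolls, batch_size)

# Average discriminator scores over rollouts -> [batch_size, n_rolls].
rewards = np.transpose(np.sum(ypred, 0)) / (1.0 * rollout_num)

# Place one reward every stride steps: indices 1, 11, ..., 91.
D_rewards = np.zeros((batch_size, max_dec_steps))
stride = int(max_dec_steps / rewards.shape[1])
for count, i in enumerate(range(1, max_dec_steps, stride)):
    D_rewards[:, i] = rewards[:, count]

print(np.count_nonzero(D_rewards, axis=1))  # 10 rewarded steps per example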
Code example #4
def pretrain_discriminator(discriminator, sess_context_manager):
    dis_train_data_loader = Dataloader(FLAGS.dis_batch_size, FLAGS.vocab_size)
    dis_test_data_loader = Dataloader(FLAGS.dis_batch_size, FLAGS.vocab_size)

    print("Pre-train Discriminator")

    pretrain_dis_data = np.load(FLAGS.pretrain_dis_data_path)
    pos_summary, neg_summary = pretrain_dis_data[
        'pos_summary_idx'], pretrain_dis_data['neg_summary_idx']
    positive_train_summary = []
    negative_train_summary = []
    positive_eval_summary = []
    negative_eval_summary = []

    ###########################################################################
    # Prepare train and eval data
    ###########################################################################

    for i in range(len(pos_summary)):
        # Hardcoded split: the first 143800 pairs train, the rest evaluate.
        if i < 143800:
            positive_train_summary.append(pos_summary[i][:FLAGS.max_dec_steps])
            negative_train_summary.append(neg_summary[i][:FLAGS.max_dec_steps])
        else:
            positive_eval_summary.append(pos_summary[i][:FLAGS.max_dec_steps])
            negative_eval_summary.append(neg_summary[i][:FLAGS.max_dec_steps])

    ###########################################################################
    # Training
    ###########################################################################
    train_max_epoch = 15
    # Unlike code example #3, the session is used directly here rather than
    # entered as a context manager, despite the parameter name.
    sess = sess_context_manager
    for epoch in tqdm(range(train_max_epoch)):
        dis_train_data_loader.load_data(positive_train_summary,
                                        negative_train_summary)
        dis_train_data_loader.reset_pointer()
        for it in range(dis_train_data_loader.num_batch):
            x_batch, y_batch = dis_train_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.input_y: y_batch,
                discriminator.dropout_keep_prob: 0.5
            }
            sess.run(discriminator.train_op, feed)

        dis_test_data_loader.load_data(positive_eval_summary,
                                       negative_eval_summary)
        dis_test_data_loader.reset_pointer()
        acc_list = []

        for it in range(dis_test_data_loader.num_batch):
            x_batch, y_batch = dis_test_data_loader.next_batch()
            feed = {
                discriminator.input_x: x_batch,
                discriminator.input_y: y_batch,
                discriminator.dropout_keep_prob: 1.0
            }
            pred = sess.run(discriminator.predictions, feed)
            # Recover integer class ids from the one-hot label rows: the last
            # element of np.where's tuple is the column indices.
            target = np.where(np.array(y_batch) == 1)[-1]
            acc_list.append(accuracy_score(y_pred=pred, y_true=target))
        eval_acc = np.mean(acc_list)

        print('Pretrain epoch:{}, Eval accuracy: {}'.format(epoch, eval_acc))
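Both pretraining variants recover integer class ids from the one-hot label rows with np.where(...)[-1]: np.where on a 2D array returns a (row_indices, column_indices) tuple, and the last element is the column holding the 1. A quick check of the idiom:

import numpy as np
from sklearn.metrics import accuracy_score

y_batch = np.array([[0, 1], [1, 0], [0, 1]])       # one-hot rows: real, fake, real
target = np.where(y_batch == 1)[-1]                # column indices -> [1 0 1]
pred = np.array([1, 0, 0])
print(accuracy_score(y_pred=pred, y_true=target))  # 0.666...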